Unverified Commit f9bc748f by Tianqi Chen Committed by GitHub

[DEPRECATION] Remove NNVM compiler (#4571)

* Remove NNVM compiler
parent 9bf2bee6
...@@ -118,9 +118,8 @@ else(MSVC) ...@@ -118,9 +118,8 @@ else(MSVC)
endif(MSVC) endif(MSVC)
# add source group # add source group
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc") FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h")
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE}) assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE}) assign_source_group("Include" ${GROUP_INCLUDE})
...@@ -174,13 +173,6 @@ if(NOT MSVC) ...@@ -174,13 +173,6 @@ if(NOT MSVC)
list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS})
endif() endif()
file(GLOB_RECURSE NNVM_COMPILER_SRCS
nnvm/src/c_api/*.cc
nnvm/src/core/*.cc
nnvm/src/pass/*.cc
nnvm/src/compiler/*.cc
nnvm/src/top/*.cc
)
file(GLOB TOPI_SRCS file(GLOB TOPI_SRCS
topi/src/*.cc topi/src/*.cc
...@@ -294,7 +286,6 @@ if(NOT USE_SGX STREQUAL "OFF") ...@@ -294,7 +286,6 @@ if(NOT USE_SGX STREQUAL "OFF")
add_dependencies(tvm_runtime sgx_edl tvm_t) add_dependencies(tvm_runtime sgx_edl tvm_t)
install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm_t ARCHIVE DESTINATION lib${LIB_SUFFIX})
endif() endif()
add_library(nnvm_compiler SHARED ${NNVM_COMPILER_SRCS})
if(USE_THREADS) if(USE_THREADS)
message(STATUS "Build with thread support...") message(STATUS "Build with thread support...")
...@@ -304,13 +295,11 @@ if(USE_THREADS) ...@@ -304,13 +295,11 @@ if(USE_THREADS)
target_link_libraries(tvm Threads::Threads) target_link_libraries(tvm Threads::Threads)
target_link_libraries(tvm_topi Threads::Threads) target_link_libraries(tvm_topi Threads::Threads)
target_link_libraries(tvm_runtime Threads::Threads) target_link_libraries(tvm_runtime Threads::Threads)
target_link_libraries(nnvm_compiler Threads::Threads)
endif(USE_THREADS) endif(USE_THREADS)
target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(tvm_topi tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(tvm_topi tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(nnvm_compiler tvm)
if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL") set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL")
...@@ -320,7 +309,6 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") ...@@ -320,7 +309,6 @@ if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_libraries(tvm ${HIDE_SYMBOLS_LINKER_FLAGS}) target_link_libraries(tvm ${HIDE_SYMBOLS_LINKER_FLAGS})
target_link_libraries(tvm_topi ${HIDE_SYMBOLS_LINKER_FLAGS}) target_link_libraries(tvm_topi ${HIDE_SYMBOLS_LINKER_FLAGS})
target_link_libraries(tvm_runtime ${HIDE_SYMBOLS_LINKER_FLAGS}) target_link_libraries(tvm_runtime ${HIDE_SYMBOLS_LINKER_FLAGS})
target_link_libraries(nnvm_compiler ${HIDE_SYMBOLS_LINKER_FLAGS})
endif() endif()
# Related headers # Related headers
...@@ -330,10 +318,7 @@ target_include_directories( ...@@ -330,10 +318,7 @@ target_include_directories(
target_include_directories( target_include_directories(
tvm_topi tvm_topi
PUBLIC "topi/include") PUBLIC "topi/include")
target_include_directories(
nnvm_compiler
PUBLIC "nnvm/include"
PUBLIC "topi/include")
# Tests # Tests
set(TEST_EXECS "") set(TEST_EXECS "")
...@@ -372,7 +357,6 @@ add_custom_target(runtime DEPENDS tvm_runtime) ...@@ -372,7 +357,6 @@ add_custom_target(runtime DEPENDS tvm_runtime)
install(TARGETS tvm DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm DESTINATION lib${LIB_SUFFIX})
install(TARGETS tvm_topi DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm_topi DESTINATION lib${LIB_SUFFIX})
install(TARGETS tvm_runtime DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm_runtime DESTINATION lib${LIB_SUFFIX})
install(TARGETS nnvm_compiler DESTINATION lib${LIB_SUFFIX})
if (INSTALL_DEV) if (INSTALL_DEV)
install( install(
...@@ -395,11 +379,6 @@ if (INSTALL_DEV) ...@@ -395,11 +379,6 @@ if (INSTALL_DEV)
FILES_MATCHING FILES_MATCHING
PATTERN "*.h" PATTERN "*.h"
) )
install(
DIRECTORY "nnvm/include/." DESTINATION "include"
FILES_MATCHING
PATTERN "*.h"
)
else(INSTALL_DEV) else(INSTALL_DEV)
install( install(
DIRECTORY "include/tvm/runtime/." DESTINATION "include/tvm/runtime" DIRECTORY "include/tvm/runtime/." DESTINATION "include/tvm/runtime"
...@@ -412,5 +391,4 @@ endif(INSTALL_DEV) ...@@ -412,5 +391,4 @@ endif(INSTALL_DEV)
if(MSVC) if(MSVC)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
endif() endif()
...@@ -69,14 +69,12 @@ build/libtvm_web_runtime.js: build/libtvm_web_runtime.bc ...@@ -69,14 +69,12 @@ build/libtvm_web_runtime.js: build/libtvm_web_runtime.bc
cpplint: cpplint:
python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src
python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include; python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include;
python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src;
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \ python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \
examples/extension/src examples/graph_executor/src examples/extension/src examples/graph_executor/src
pylint: pylint:
python3 -m pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc python3 -m pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
python3 -m pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc python3 -m pylint topi/python/topi --rcfile=$(ROOTDIR)/tests/lint/pylintrc
python3 -m pylint nnvm/python/nnvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
python3 -m pylint vta/python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc python3 -m pylint vta/python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
jnilint: jnilint:
......
...@@ -26,7 +26,3 @@ cd .. ...@@ -26,7 +26,3 @@ cd ..
cd topi/python cd topi/python
$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt $PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt
cd ../.. cd ../..
cd nnvm/python
$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt
cd ../..
...@@ -48,7 +48,6 @@ test: ...@@ -48,7 +48,6 @@ test:
imports: imports:
- tvm - tvm
- topi - topi
- nnvm
requires: requires:
- pytest - pytest
- scipy - scipy
......
...@@ -70,5 +70,5 @@ RUN cd /usr && \ ...@@ -70,5 +70,5 @@ RUN cd /usr && \
make -j10 make -j10
# Environment variables # Environment variables
ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/nnvm/python/:/usr/tvm/vta/python:${PYTHONPATH} ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/vta/python:${PYTHONPATH}
ENV ANDROID_HOME=/opt/android-sdk-linux/ ENV ANDROID_HOME=/opt/android-sdk-linux/
...@@ -30,4 +30,4 @@ COPY install/install_tvm_cpu.sh /install/install_tvm_cpu.sh ...@@ -30,4 +30,4 @@ COPY install/install_tvm_cpu.sh /install/install_tvm_cpu.sh
RUN bash /install/install_tvm_cpu.sh RUN bash /install/install_tvm_cpu.sh
# Environment variables # Environment variables
ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/nnvm/python/:/usr/tvm/vta/python:${PYTHONPATH} ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/vta/python:${PYTHONPATH}
...@@ -28,7 +28,7 @@ COPY install/install_tvm_gpu.sh /install/install_tvm_gpu.sh ...@@ -28,7 +28,7 @@ COPY install/install_tvm_gpu.sh /install/install_tvm_gpu.sh
RUN bash /install/install_tvm_gpu.sh RUN bash /install/install_tvm_gpu.sh
# Environment variables # Environment variables
ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/nnvm/python/:/usr/tvm/vta/python:${PYTHONPATH} ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/vta/python:${PYTHONPATH}
ENV PATH=/usr/local/nvidia/bin:${PATH} ENV PATH=/usr/local/nvidia/bin:${PATH}
ENV PATH=/usr/local/cuda/bin:${PATH} ENV PATH=/usr/local/cuda/bin:${PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
...@@ -76,7 +76,6 @@ RUN mkdir -p ${TVM_BUILD_DIR} && \ ...@@ -76,7 +76,6 @@ RUN mkdir -p ${TVM_BUILD_DIR} && \
make -j6 make -j6
RUN echo "Building Python package" RUN echo "Building Python package"
ENV PYTHONPATH=${TVM_HOME}/python:${TVM_HOME}/topi/python:${TVM_HOME}/nnvm/python:${PYTHONPATH} ENV PYTHONPATH=${TVM_HOME}/python:${TVM_HOME}/topi/python:${PYTHONPATH}
RUN cd ${TVM_HOME}/python && python3 setup.py install --user RUN cd ${TVM_HOME}/python && python3 setup.py install --user
RUN cd ${TVM_HOME}/topi/python && python3 setup.py install --user RUN cd ${TVM_HOME}/topi/python && python3 setup.py install --user
RUN cd ${TVM_HOME}/nnvm/python && python3 setup.py install --user
...@@ -32,7 +32,7 @@ import java.lang.reflect.Method; ...@@ -32,7 +32,7 @@ import java.lang.reflect.Method;
public class GraphRuntime { public class GraphRuntime {
/** /**
* Create a runtime executor module given a graph and module. * Create a runtime executor module given a graph and module.
* @param graphJson The graph deployed in json format output by nnvm graph. * @param graphJson The graph deployed in json format output by compiler.
* @param libmod The module of the corresponding function. * @param libmod The module of the corresponding function.
* @param ctx The local or remote context to deploy the module. * @param ctx The local or remote context to deploy the module.
* @return Runtime graph module that can be used to execute the graph. * @return Runtime graph module that can be used to execute the graph.
......
...@@ -30,7 +30,6 @@ TVMPATH = .. ...@@ -30,7 +30,6 @@ TVMPATH = ..
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
CFLAGS += -I$(TVMPATH)/include -I$(TVMPATH)/3rdparty/dlpack/include -I$(TVMPATH)/3rdparty/HalideIR/src -I$(TVMPATH)/topi/include
ifdef DMLC_CORE_PATH ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include CFLAGS += -I$(DMLC_CORE_PATH)/include
...@@ -66,7 +65,7 @@ else ...@@ -66,7 +65,7 @@ else
NO_WHOLE_ARCH= --no-whole-archive NO_WHOLE_ARCH= --no-whole-archive
endif endif
all: lib/libnnvm.a lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX) all: lib/libnnvm.a lib/libnnvm.$(SHARED_LIBRARY_SUFFIX)
SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc) SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
SRC_COMPILER = $(wildcard src/top/*/*.cc wildcard src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc) SRC_COMPILER = $(wildcard src/top/*/*.cc wildcard src/top/vision/*/*.cc src/compiler/*.cc src/compiler/*/*.cc)
...@@ -87,7 +86,7 @@ lib/libnnvm.a: $(ALL_DEP) ...@@ -87,7 +86,7 @@ lib/libnnvm.a: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
$(AR) crv $@ $(filter %.o, $?) $(AR) crv $@ $(filter %.o, $?)
lib/libnnvm_compiler.$(SHARED_LIBRARY_SUFFIX): lib/libnnvm.a ${TOP_OBJ} lib/libnnvm.$(SHARED_LIBRARY_SUFFIX): lib/libnnvm.a ${TOP_OBJ}
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) -Wl,${WHOLE_ARCH} lib/libnnvm.a -Wl,${NO_WHOLE_ARCH} $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) -Wl,${WHOLE_ARCH} lib/libnnvm.a -Wl,${NO_WHOLE_ARCH}
......
...@@ -15,38 +15,8 @@ ...@@ -15,38 +15,8 @@
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
# NNVM Compiler Module of TVM Stack # NNVM
```python NNVM is a graph level IR for neural networks.
import tvm We are moving towards Relay IR, a better unified IR that support wider range of programs.
from tvm.contrib import graph_runtime, rpc Please use relay instead.
import nnvm.frontend
import nnvm.compiler
# GET model from frameworks
# change xyz to supported framework name.
graph, params = nnvm.frontend.from_xyz(...)
# OPTIMIZE and COMPILE the graph to get a deployable module
# target can be "opencl", "llvm", "metal" or any target supported by tvm
target = "cuda"
graph, lib, params = nnvm.compiler.build(graph, target, {"data", data_shape}, params=params)
# DEPLOY and run on gpu(0)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
module.set_input(**params)
module.run(data=data_array)
output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
module.get_output(0, output)
# DEPLOY to REMOTE mobile/rasp/browser with minimum tvm rpc runtime
# useful for quick experiments on mobile devices
remote = rpc.connect(remote_host, remote_port)
lib.export_library("mylib.so")
remote.upload("mylib.so")
rlib = rpc.load_module("mylib.so")
# run on remote device
rmodule = graph_runtime.create(graph, rlib, remote.gpu(0))
rmodule.set_input(**params)
rmodule.run()
```
...@@ -46,6 +46,24 @@ using dmlc::get; ...@@ -46,6 +46,24 @@ using dmlc::get;
/*!\brief "unsafe" getter function of any type */ /*!\brief "unsafe" getter function of any type */
using dmlc::unsafe_get; using dmlc::unsafe_get;
enum TypeFlag {
kFloat32 = 0,
kFloat64 = 1,
kFloat16 = 2,
kUint8 = 3,
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
// kBool = 7,
// 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in
// https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L314
kInt16 = 8,
kUint16 = 9,
kUint32 = 10,
kUint64 = 11,
kBfloat16 = 12,
};
} // namespace nnvm } // namespace nnvm
// describe op registration point // describe op registration point
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file nnvm/compiler/op_attr_types.h
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_
#define NNVM_COMPILER_OP_ATTR_TYPES_H_
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>
#include "packed_func_ext.h"
namespace nnvm {
namespace compiler {
using ::tvm::Array;
using ::tvm::Tensor;
using ::tvm::Schedule;
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operator can still be safely fused to injective and reduction.
kInjective = 2,
// Communicative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
/*! \brief the operator pattern */
using TOpPattern = int;
/*!
* \brief Computation description interface
* \param attrs The attribute of the node.
* \param inputs The input tensors(placeholders)
* \param out_info Tensors holding shape/type information about output,
& these are always placeholders.
* \return The output description of the tensor.
*/
using FTVMCompute = std::function<
Array<Tensor>(const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info)>;
/*!
* \brief Build the computation schedule for
* op whose root is at current op.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return schedule The computation schedule.
*/
using FTVMSchedule = std::function<
Schedule(const NodeAttrs& attrs,
const Array<Tensor>& outs,
const std::string& target)>;
/*!
* \brief Modify the op node to alter its input layout.
* it is invoked in AlterOpLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos The inferred shape and dtype of the inputs.
* \param ret The replaced operator.
* \return Whether to replace current operator.
*/
using FTVMAlterOpLayout = std::function<
bool(const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos,
Symbol* ret)>;
/*!
* \brief Transform from normal operator to vectorized operator
* \param node The source node.
* \return Transformed vectorized op.
*/
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node* node)>;
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_OP_ATTR_TYPES_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file nnvm/compiler/packed_func_ext.h
* \brief Extension to enable packed functionn for nnvm types
*/
#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_
#define NNVM_COMPILER_PACKED_FUNC_EXT_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <vector>
#include <unordered_map>
namespace nnvm {
namespace compiler {
using tvm::runtime::PackedFunc;
using AttrDict = std::unordered_map<std::string, std::string>;
/*!
* \brief Get PackedFunction from global registry and
* report error if it does not exist
* \param name The name of the function.
* \return The created PackedFunc.
*/
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
} // namespace compiler
} // namespace nnvm
// Enable the graph and symbol object exchange.
namespace tvm {
namespace runtime {
template<>
struct extension_type_info<nnvm::Symbol> {
static const int code = 16;
};
template<>
struct extension_type_info<nnvm::Graph> {
static const int code = 17;
};
template<>
struct extension_type_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};
} // namespace runtime
} // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file nnvm/compiler/util.h
* \brief Utility functions for nnvm compiler
*/
#ifndef NNVM_COMPILER_UTIL_H_
#define NNVM_COMPILER_UTIL_H_
#include <tvm/expr.h>
#include <nnvm/tuple.h>
namespace nnvm {
namespace compiler {
/*
* \brief Helper function to convert TShape to TVM array. Useful for
* passing data from NNVM param structures to TOPI ops.
*
* \param shape The shape to convert
*
* \return An Array of Expr, where each element is a constant int32
*/
inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
tvm::Array<tvm::Expr> result;
for (auto i : shape) {
result.push_back(tvm::make_const(tvm::DataType::Int(32), i));
}
return result;
}
/*
* \brief Helper function to convert TShape to TVM array. Useful for
* passing data from NNVM param structures to TOPI ops.
*
* \param shape The shape to convert
*
* \return An Array of Expr, where each element is a constant int32
*/
inline tvm::Array<tvm::Integer> ShapeToIntArray(TShape shape) {
return tvm::Downcast<tvm::Array<tvm::Integer> >(ShapeToArray(shape));
}
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_UTIL_H_
NNVM Core Operator and Compiler
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#!/usr/bin/env python
# coding: utf-8
"""NNVM python API for ease of use and help new framework establish python API. """
from __future__ import absolute_import as _abs
import warnings
from . import _base
from . import symbol as sym
from . import symbol
from ._base import NNVMError
from . import frontend
__version__ = _base.__version__
warnings.warn("NNVM is deprecated and will be removed in a future version. Use Relay instead.",
FutureWarning)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, unused-import
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import
import os
import sys
import ctypes
import numpy as np
from . import libinfo
try:
import tvm
except ImportError:
pass
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = str
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = basestring
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x
class NNVMError(Exception):
"""Error that will be throwed by all nnvm functions"""
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL)
# DMatrix functions
lib.NNGetLastError.restype = ctypes.c_char_p
return lib
# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()
# The FFI mode of TVM
_FFI_MODE = os.environ.get("TVM_FFI", "auto")
# type definitions
nn_uint = ctypes.c_uint
OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p
# Global dict of str to symbol to initialize variables
_all_var_init = {}
#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise NNVMError(py_str(_LIB.NNGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)
def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type.
Parameters
----------
cptr : ctypes.POINTER(ctypes.c_char)
pointer to the raw memory region
length : int
the length of the buffer
Returns
-------
buffer : bytearray
The raw byte memory buffer
"""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise TypeError('expected char pointer')
res = bytearray(length)
rptr = (ctypes.c_char * length).from_buffer(res)
if not ctypes.memmove(rptr, cptr, length):
raise RuntimeError('memmove failed')
return res
def ctypes2numpy_shared(cptr, shape):
"""Convert a ctypes pointer to a numpy array
The result numpy array shares the memory with the pointer
Parameters
----------
cptr : ctypes.POINTER(mx_float)
pointer to the memory region
shape : tuple
shape of target ndarray
Returns
-------
out : numpy_array
A numpy array : numpy array
"""
if not isinstance(cptr, ctypes.POINTER(mx_float)):
raise RuntimeError('expected float pointer')
size = 1
for s in shape:
size *= s
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if arg_descs[i]:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
Ctypes specific implementation of certain modules
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""""ctypes implementation of the Symbol"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines,
# pylint: disable=len-as-condition, consider-iterating-dictionary
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
import copy
import ctypes
import sys
from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str
from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring
from ..name import NameManager
from ..attribute import AttrScope
class SymbolBase(object):
"""Symbol is symbolic graph."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
check_call(_LIB.NNSymbolFree(self.handle))
def __call__(self, *args, **kwargs):
"""Invoke symbol as function on inputs.
Parameters
----------
args:
provide positional arguments
kwargs:
provide keyword arguments
Returns
-------
the resulting symbol
"""
s = copy.deepcopy(self)
s._compose(*args, **kwargs)
return s
def _compose(self, *args, **kwargs):
"""Compose symbol on inputs.
This call mutates the current symbol.
Parameters
----------
args:
provide positional arguments
kwargs:
provide keyword arguments
Returns
-------
the resulting symbol
"""
name = kwargs.pop('name', None)
if name:
name = c_str(name)
if len(args) != 0 and len(kwargs) != 0:
raise TypeError('compose only accept input Symbols \
either as positional or keyword arguments, not both')
for arg in args:
if not isinstance(arg, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
for val in kwargs.values():
if not isinstance(val, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
num_args = len(args) + len(kwargs)
if len(kwargs) != 0:
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
args = c_array(SymbolHandle, [s.handle for s in kwargs.values()])
else:
keys = None
args = c_array(SymbolHandle, [s.handle for s in args])
check_call(_LIB.NNSymbolCompose(
self.handle, name, num_args, keys, args))
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
"""
keys = c_array(ctypes.c_char_p,
[c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p,
[c_str(str(val)) for val in kwargs.values()])
num_args = nn_uint(len(kwargs))
check_call(_LIB.NNSymbolSetAttrs(
self.handle, num_args, keys, vals))
_symbol_cls = SymbolBase
def _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
def _make_atomic_symbol_function(handle, name):
"""Create an atomic symbol function by handle and funciton name."""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = nn_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.NNGetOpInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = name
desc = py_str(desc.value)
doc_str = ('%s\n\n' +
'%s\n' +
'Returns\n' +
'-------\n' +
'result: Tensor\n' +
' The result Tensor.')
doc_str = doc_str % (desc, param_str)
def creator(*args, **kwargs):
"""Activation Operator of Neural Net.
The parameters listed below can be passed in as keyword arguments.
Parameters
----------
name : string, required.
Name of the resulting symbol.
Returns
-------
symbol: Symbol
the resulting symbol
"""
param_keys = []
param_vals = []
symbol_kwargs = {}
name = kwargs.pop('name', None)
attr = kwargs.pop('attr', None)
for k, v in kwargs.items():
if isinstance(v, SymbolBase):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
param_vals.append(c_str(str(v)))
# create atomic symbol
param_keys = c_array(ctypes.c_char_p, param_keys)
param_vals = c_array(ctypes.c_char_p, param_vals)
sym_handle = SymbolHandle()
check_call(_LIB.NNSymbolCreateAtomicSymbol(
handle,
nn_uint(len(param_keys)),
param_keys, param_vals,
ctypes.byref(sym_handle)))
if len(args) != 0 and len(symbol_kwargs) != 0:
raise TypeError(
'%s can only accept input'
'Symbols either as positional or keyword arguments, not both' % func_name)
s = _symbol_cls(sym_handle)
attr = AttrScope.current.get(attr)
if attr:
s._set_attr(**attr)
hint = func_name.lower()
name = NameManager.current.get(name, hint)
s._compose(*args, name=name, **symbol_kwargs)
return s
creator.__name__ = func_name
creator.__doc__ = doc_str
return creator
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
_set_symbol_class(symbol_class)
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.NNListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.symbol" % root_namespace]
module_obj_contrib = sys.modules["%s.contrib" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_contrib_'):
setattr(module_obj_contrib, function.__name__.split('_contrib_')[1], function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
setattr(module_obj, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
This folder is by default empty and will hold DLLs generated by cython.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for cython generated modules for python2"""
This folder is by default empty and will hold DLLs generated by cython.
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Cython generated modules"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module space to register internal functions. Leave empty"""
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import
from ._base import string_types
class AttrScope(object):
"""Attribute manager for scoping.
User can also inherit this object to change naming behavior.
Parameters
----------
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None
def __init__(self, **kwargs):
self._old_scope = None
for value in kwargs.values():
if not isinstance(value, string_types):
raise ValueError("Attributes need to be string")
self._attr = kwargs
def get(self, attr):
"""
Get the attribute dict given the attribute set by the symbol.
Parameters
----------
attr : dict of string to string
The attribute passed in by user during symbol creation.
Returns
-------
attr : dict of string to string
Updated attributes to add other scope related attributes.
"""
if self._attr:
ret = self._attr.copy()
if attr:
ret.update(attr)
return ret
return attr
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
attr = AttrScope.current._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope
AttrScope.current = AttrScope()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""NNVM compiler toolchain.
User only need to use :any:`build` and :any:`build_config` to do the compilation,
and :any:`save_param_dict` to save the parameters into bytes.
The other APIs are for more advanced interaction with the compiler toolchain.
"""
from __future__ import absolute_import
import tvm
from . import build_module
from . build_module import build, optimize, build_config
from . compile_engine import engine, graph_key
from . param_dict import save_param_dict, load_param_dict
from .. import symbol as _symbol
from .. import graph as _graph
from .. import top as _top
tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
tvm.register_extension(_graph.Graph, _graph.Graph)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine
You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")
@tvm.register_node
class GraphKey(tvm.node.NodeBase):
"""Key of a graph compilation context"""
@property
def graph(self):
return _graph_key_get_graph(self)
@tvm.register_node
class GraphCacheEntry(tvm.node.NodeBase):
"""CacheEntry of compilation into a TVM Function"""
@tvm.register_node
class GraphFunc(tvm.node.NodeBase):
"""Compiled result of a graph into a TVM Function"""
class Engine(object):
"""Global singleton compilation engine.
You can get the singleton at ``nnvm.compiler.engine``
"""
def items(self):
"""List the available cache key value pairs.
Returns
-------
item_list : list of (GraphKey, GraphCacheEntry)
The existing cache items
"""
res = _list_cache_items()
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]
def clear_cache(self):
"""Clear the existing cached functions."""
_clear_cache()
def __setitem__(self, key, value):
"""Clear the existing cached functions."""
if isinstance(value, GraphCacheEntry):
_set_cache_item(key, value.graph_func)
else:
_set_cache_item(key, value)
def __getitem__(self, key):
"""Clear the existing cached functions."""
return _get_cache_item(key)
def dump(self):
"""Return a string representation of engine dump
Returns
-------
dump : str
The dumped string representation
"""
items = self.items()
res = "====================================\n"
res += "CompilerEngine dump, %d items cached\n" % len(items)
for key, value in items:
res += "------------------------------------\n"
res += "target={}\n".format(key.target)
res += "inputs={}\n".format(key.inputs)
res += "use_count={}\n".format(value.use_count)
res += "func_name={}\n".format(value.graph_func.func_name)
res += key.graph.ir() + "\n"
res += "===================================\n"
return res
engine = Engine()
def graph_key(graph, inputs, target):
"""Construct a new graph key.
Parameters
----------
graph : Graph
The computation graph structure
inputs : list of Tensor(placeholder)
The input requirement to the graph.
target : str
The target of compilation.
"""
return _make_graph_key(graph, inputs, target)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Utilities to access graph attributes"""
from __future__ import absolute_import as _abs
import tvm
def set_shape_inputs(g, shape):
"""Set the shape of input graph nodes in the graph attribute.
Parameters
----------
g : Graph
The input graph
shape : dict of str to tuple
The input shape
Returns
-------
g : Graph
The updated graph with updated shape.
"""
list_shape = [
shape.get(name, ()) for name in g.index.input_names]
g._set_json_attr("shape_inputs", list_shape, 'list_shape')
return g
DTYPE_TO_TCODE = {
"default": -1,
"float32": 0,
"float64": 1,
"float16": 2,
"uint8": 3,
"int32": 4,
"int8": 5,
"int64": 6,
"int16": 7,
"uint16": 8,
"uint32": 9,
"uint64": 10,
"bool": 11,
}
TCODE_TO_DTYPE = {
-1: None,
0: "float32",
1: "float64",
2: "float16",
3: "uint8",
4: "int32",
5: "int8",
6: "int64",
7: "int16",
8: "uint16",
9: "uint32",
10: "uint64",
11: "bool",
}
def set_dtype_inputs(g, dtype):
"""Set the dtype inputs of graph nodes
Parameters
----------
g : Graph
The input graph
dtype : dict of str to str or str
The input dtype
Returns
-------
g : Graph
The updated graph with updated dtype.
"""
if isinstance(dtype, dict):
list_dtype = [
DTYPE_TO_TCODE[str(dtype.get(name, "default"))]
for name in g.index.input_names]
else:
list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
g._set_json_attr("dtype_inputs", list_dtype, "list_int")
return g
def set_layout_inputs(g, layout):
"""Set the layout inputs of graph nodes
Parameters
----------
g : Graph
The input graph
layout : dict of str to str or str
The input layout
Returns
-------
g : Graph
The updated graph with updated layout.
"""
if isinstance(layout, dict):
list_layout = [
layout.get(name, "__undef__") for name in g.index.input_names]
elif isinstance(layout, str):
list_layout = ["__undef__"] * len(g.index.input_names)
list_layout[0] = layout
else:
raise ValueError("Input layout must be str or dict")
last_inferred_layouts = g.json_attr("layout")
if last_inferred_layouts:
input_layout = [last_inferred_layouts[g.index.entry_id(x)] for x in g.index.input_names]
for i, layout_stored in enumerate(input_layout):
list_layout[i] = list_layout[i] if list_layout[i] != '__undef__' else layout_stored
g._set_json_attr("layout_inputs", list_layout, 'list_layout')
return g
_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph._move_graph")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Namespace of graph pass.
Principle:
- Graph in, graph out: always takes in graph as first argument and returns a graph
- Composable API: break graph transformation pass as segments of small transformations.
"""
from __future__ import absolute_import as _abs
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Utility function to get information from graph."""
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr
from ..graph import create
from ..symbol import Group, ones_like
def infer_shape(graph, **shape):
"""Infer the shape given the shape of inputs.
Parameters
----------
graph : Graph
The graph to perform shape inference from
shape : dict of str to tuple
The specific input shape.
Returns
-------
in_shape : list of tuple
Shape of inputs
out_shape: list of tuple
Shape of outputs
"""
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
shape = graph.json_attr("shape")
index = graph.index
input_shape = [shape[index.entry_id(x)] for x in index.input_names]
output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
return input_shape, output_shape
def infer_dtype(graph, **dtype):
"""Infer the type given the typeS of inputs.
Parameters
----------
graph : Graph
The graph to perform type inference from
dtype : dict of str to dtype
The specific input data type.
Returns
-------
in_dtype : list of tuple
Dtype of inputs
out_dtype: list of tuple
Dtype of outputs
"""
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph = graph.apply("InferType")
dtype = graph.json_attr("dtype")
index = graph.index
input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.input_names]
output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.output_entries]
return input_dtype, output_dtype
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
"""Check if two graphs have equal structure.
Parameters
----------
grapha : Graph
The first graph
graphb : Graph
The second graph
compare_variable_attrs : bool, optional
Whether we want to compare attributes(names) on variables.
Usually it is safe to skip it unless we want input name
to exactly match
Raises
------
ValueError
ValueError is raised with error message when graph not equal
"""
err = _deep_compare(grapha, graphb, compare_variable_attrs)
if err:
raise ValueError("Graph compare error: " + err)
def get_gradient_graph(ys, xs, grad_ys=None):
"""Create gradient graph of ys with respect to xs.
Parameters
----------
ys : Symbol or list of Symbol
Symbols from which the gradient is calculated.
xs : Symbol or list of Symbol
Symbols the gradient respect to.
For group symbol, gradients for all outputs will be calculated.
grad_ys : Symbol or list of Symbol
Head gradients for ys.
Returns
-------
ret : Graph
Generated gradient graph.
"""
if isinstance(ys, list):
ys = Group(ys)
g = create(ys)
g._set_symbol_list_attr('grad_ys', ys)
g._set_symbol_list_attr('grad_xs', xs)
ny = len(ys.list_output_names())
if grad_ys is None:
grad_ys = [ones_like(ys[i]) for i in range(ny)]
g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
return g.apply('Gradient')
def gradients(ys, xs, grad_ys=None):
"""Create gradient symbol of ys respect to xs.
Parameters
----------
ys : Symbol or list of Symbol
Symbols from which the gradient is calculated.
xs : Symbol or list of Symbol
Symbols the gradient respect to.
For group symbol, gradients for all outputs will be calculated.
grad_ys : Symbol or list of Symbol
Head gradients for ys.
Returns
-------
ret : list of Symbol
Generated gradient symbol. For each xs,
all gradients from ys are merged into a single symbol.
"""
grad_g = get_gradient_graph(ys, xs, grad_ys)
nx = len(Group(xs).list_output_names()) \
if isinstance(xs, list) else len(xs.list_output_names())
ret = [grad_g.symbol[i] for i in range(nx)]
return ret
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods, no-member
"""API for scheduling learning rate."""
from .. import symbol as sym
class LRScheduler(object):
"""Base class of a learning rate scheduler.
A scheduler returns a new learning rate based on the number of updates that have
been performed.
Parameters
----------
base_lr : float, optional
The initial learning rate.
"""
def __init__(self, base_lr=0.01, name='LRScheduler'):
self.name = name
self.base_lr = base_lr
def __call__(self, num_update):
"""Return a new learning rate based on number of updates.
Parameters
----------
num_update: nnvm Symbol
the number of updates applied to weight.
"""
raise NotImplementedError("__call__ method must be overridden.")
class FactorScheduler(LRScheduler):
"""Reduce the learning rate by a factor for every *n* steps.
It returns a new learning rate by::
base_lr * pow(factor, num_update/step)
Parameters
----------
step : int
Changes the learning rate for every n updates.
factor : float, optional
The factor to change the learning rate.
stop_factor_lr : float, optional
Stop updating the learning rate if it is less than this value.
"""
def __init__(self, step, factor=1, stop_factor_lr=1e-8, name='FactorScheduler', **kwargs):
super(FactorScheduler, self).__init__(name=name, **kwargs)
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor > 1.0:
raise ValueError("Factor must be no more than 1 to make lr reduce")
self.step = step
self.factor = factor
self.stop_factor_lr = stop_factor_lr
def __call__(self, num_update):
updated_lr = self.base_lr * self.factor ** (num_update / self.step)
return sym.clip(updated_lr, a_min=self.stop_factor_lr, a_max=self.base_lr)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-few-public-methods, too-many-arguments, too-many-locals, protected-access
"""Optimizer API"""
from . import graph_util
from .. import symbol as sym
class Optimizer(object):
"""Base class inherited by all optimizers.
Parameters
----------
learning_rate : float, optional
The initial learning rate.
lr_scheduler : LRScheduler, optional
The learning rate scheduler.
rescale_grad : float, optional
Multiply the gradient with `rescale_grad` before updating. Often
choose to be ``1.0/batch_size``.
clip_gradient : float, optional
Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
wd : float, optional
The weight decay (or L2 regularization) coefficient. Modifies objective
by adding a penalty for having large weights.
name : string, optional
The name of optimizer.
"""
def __init__(self, learning_rate=0.01, lr_scheduler=None,
rescale_grad=1, clip_gradient=None, wd=0, name="Optimizer"):
self.name = name
self.lr = learning_rate
self.lr_scheduler = lr_scheduler
self.rescale_grad = rescale_grad
self.clip_gradient = clip_gradient
self.wd = wd
init_update_t = sym.Variable(name+'_t', init=sym.zeros(shape=(1,), dtype="int32"))
self.update_t = sym._assign(init_update_t, init_update_t + 1)
def minimize(self, obj, var=None):
"""Minimize given obj symbol respect to var. If var is not set, all input
variables of obj will be used.
Parameters
----------
obj : nnvm Symbol or list of nnvm Symbols
Symbols to be minimized.
var : nnvm Symbol or list of nnvm Symbols, optional
Symbols the gradient respect to.
Returns
-------
group_sym : nnvm Symbol
Group symbol represents update symbols.
"""
raise NotImplementedError()
def _get_lr(self):
"""Gets the learning rate with learning rate scheduler.
Returns
-------
lr : float
Learning rate.
"""
if self.lr_scheduler is not None:
lr = self.lr_scheduler(self.update_t)
else:
lr = self.lr
return lr
class SGD(Optimizer):
"""The SGD optimizer
"""
def __init__(self, name='SGD', **kwargs):
super(SGD, self).__init__(name=name, **kwargs)
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
lr_t = self._get_lr()
for v, g in zip(variables, grads):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
updates.append(sym._assign(v, v - lr_t * (g + self.wd * v)))
return sym.Group(updates)
class Adam(Optimizer):
"""The Adam optimizer.
This class implements the optimizer described in *Adam: A Method for
Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999,
epsilon=1e-8, name='Adam', **kwargs):
super(Adam, self).__init__(learning_rate=learning_rate, name=name, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = []
self.v = []
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
for i, v in enumerate(variables):
self.m.append(sym.Variable(self.name + '_m' + str(i), init=sym.zeros_like(v)))
self.v.append(sym.Variable(self.name + '_v' + str(i), init=sym.zeros_like(v)))
rate = sym.sqrt(1 - self.beta2 ** self.update_t) / (1 - self.beta1 ** self.update_t)
lr_t = self._get_lr() * rate
for variable, g, m, v in zip(variables, grads, self.m, self.v):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
update_m = sym._assign(m, self.beta1 * m + (1 - self.beta1) * g)
update_v = sym._assign(v, self.beta2 * v + (1 - self.beta2) * g * g)
update_var = sym._assign(variable, variable - lr_t * (update_m / (sym.sqrt(update_v) \
+ self.epsilon) + self.wd * variable))
updates.append(update_var)
return sym.Group(updates)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import tvm
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
def save_param_dict(params):
"""Save parameter dictionary to binary bytes.
The result binary bytes can be loaded by the
GraphModule with API "load_params".
Parameters
----------
params : dict of str to NDArray
The parameter dictionary.
Returns
-------
param_bytes: bytearray
Serialized parameters.
Examples
--------
.. code-block:: python
# compile and save the modules to file.
graph, lib, params = nnvm.compiler.build(
graph, target, shape={"data", data_shape}, params=params)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
# save the parameters as byte array
param_bytes = nnvm.compiler.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
module["load_params"](param_bytes)
"""
args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_param_dict(*args)
def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.
Parameters
----------
param_bytes: bytearray
Serialized parameters.
Returns
-------
params : dict of str to NDArray
The parameter dictionary.
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_arr = _load_param_dict(param_bytes)
return {v.name : v.array for v in load_arr}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Module space to register contrib functions. Leave empty"""
Cython specific implementation of certain modules
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ctypedef void* SymbolHandle
ctypedef void* OpHandle
ctypedef unsigned nn_uint
cdef py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x
else:
return x.decode("utf-8")
cdef c_str(pystr):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return pystr.encode("utf-8")
cdef CALL(int ret):
if ret != 0:
raise NNVMError(NNGetLastError())
cdef const char** CBeginPtr(vector[const char*]& vec):
if (vec.size() != 0):
return &vec[0]
else:
return NULL
cdef vector[const char*] SVec2Ptr(vector[string]& vec):
cdef vector[const char*] svec
svec.resize(vec.size())
for i in range(vec.size()):
svec[i] = vec[i].c_str()
return svec
cdef BuildDoc(nn_uint num_args,
const char** arg_names,
const char** arg_types,
const char** arg_descs,
remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args):
key = arg_names[i]
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = arg_types[i]
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import as _abs
import sys as _sys
import ctypes as _ctypes
from numbers import Number as _Number
from .._base import NNVMError
from ..name import NameManager
from ..attribute import AttrScope
from libcpp.vector cimport vector
from libcpp.string cimport string
from cpython.version cimport PY_MAJOR_VERSION
include "./base.pyi"
cdef extern from "nnvm/c_api.h":
const char* NNGetLastError();
int NNListAllOpNames(nn_uint *out_size,
const char ***out_array);
int NNGetOpHandle(const char *op_name,
OpHandle *handle);
int NNGetOpInfo(OpHandle op,
const char **name,
const char **description,
nn_uint *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
int NNListOpNames(nn_uint *out_size,
const char ***out_array);
int NNSymbolCreateAtomicSymbol(OpHandle op,
nn_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
int NNSymbolFree(SymbolHandle symbol);
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** values);
int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
const char** keys,
SymbolHandle* args);
cdef class SymbolBase:
"""Symbol is symbolic graph."""
# handle for symbolic operator.
cdef SymbolHandle handle
def __init__(self, handle):
cdef unsigned long ptr
if handle is None:
self.handle = NULL
else:
ptr = handle.value
self.handle = <SymbolHandle>(ptr)
def __dealloc__(self):
CALL(NNSymbolFree(self.handle))
@property
def handle(self):
return _ctypes.cast(<unsigned long>self.handle, _ctypes.c_void_p)
def _set_attr(self, **kwargs):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
"""
SymbolSetAttr(self.handle, kwargs)
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef nn_uint num_args
for k, v in kwargs.items():
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
# keep strings in vector
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
num_args = param_keys.size()
CALL(NNSymbolSetAttrs(
handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
_symbol_cls = SymbolBase
cdef _set_symbol_class(cls):
global _symbol_cls
_symbol_cls = cls
cdef NewSymbol(SymbolHandle handle):
"""Create a new symbol given handle"""
sym = _symbol_cls(None)
(<SymbolBase>sym).handle = handle
return sym
cdef _make_atomic_symbol_function(OpHandle handle, string name):
"""Create an atomic symbol function by handle and funciton name."""
cdef const char *real_name
cdef const char *desc
cdef nn_uint num_args
cdef const char** arg_names
cdef const char** arg_types
cdef const char** arg_descs
cdef const char* return_type
CALL(NNGetOpInfo(
handle, &real_name, &desc,
&num_args, &arg_names,
&arg_types, &arg_descs,
&return_type))
param_str = BuildDoc(num_args, arg_names, arg_types, arg_descs)
func_name = py_str(name.c_str())
doc_str = ('%s\n\n' +
'%s\n' +
'Returns\n' +
'-------\n' +
'result: Tensor\n' +
' The result Tensor.')
doc_str = doc_str % (desc, param_str)
func_hint = func_name.lower()
def creator(*args, **kwargs):
cdef vector[string] sparam_keys
cdef vector[string] sparam_vals
cdef vector[SymbolHandle] symbol_args
cdef vector[string] ssymbol_keys
cdef SymbolHandle ret_handle
name = kwargs.pop("name", None)
attr = kwargs.pop("attr", None)
if len(kwargs) != 0:
for k, v in kwargs.items():
if isinstance(v, SymbolBase):
ssymbol_keys.push_back(c_str(k))
symbol_args.push_back((<SymbolBase>v).handle)
else:
sparam_keys.push_back(c_str(k))
sparam_vals.push_back(c_str(str(v)))
if len(args) != 0:
if symbol_args.size() != 0:
raise TypeError("compose only accept input Symbols\
either as positional or keyword arguments, not both")
for v in args:
if not isinstance(v, SymbolBase):
raise TypeError('Compose expect `Symbol` as arguments')
symbol_args.push_back((<SymbolBase>v).handle)
cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys)
cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals)
cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys)
CALL(NNSymbolCreateAtomicSymbol(
handle,
<nn_uint>param_keys.size(),
CBeginPtr(param_keys),
CBeginPtr(param_vals),
&ret_handle))
num_args = <nn_uint>(symbol_args.size())
attr = AttrScope.current.get(attr)
if attr:
SymbolSetAttr(ret_handle, attr)
name = NameManager.current.get(name, func_hint)
cdef const char* c_name = NULL
if name:
name = c_str(name)
c_name = name
CALL(NNSymbolCompose(
ret_handle,
c_name,
num_args,
&symbol_keys[0] if symbol_keys.size() != 0 else NULL,
&symbol_args[0] if symbol_args.size() != 0 else NULL))
return NewSymbol(ret_handle)
creator.__name__ = func_name
creator.__doc__ = doc_str
return creator
def _init_symbol_module(symbol_class, root_namespace):
"""List and add all the atomic symbol functions to current module."""
cdef const char** op_name_ptrs
cdef nn_uint size
cdef vector[string] op_names
cdef OpHandle handle
_set_symbol_class(symbol_class)
CALL(NNListAllOpNames(&size, &op_name_ptrs))
for i in range(size):
op_names.push_back(string(op_name_ptrs[i]));
module_obj = _sys.modules["%s.symbol" % root_namespace]
module_internal = _sys.modules["%s._symbol_internal" % root_namespace]
for i in range(op_names.size()):
CALL(NNGetOpHandle(op_names[i].c_str(), &handle))
function = _make_atomic_symbol_function(handle, op_names[i])
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
setattr(module_obj, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""NNVM frontends."""
from __future__ import absolute_import
from .mxnet import from_mxnet
from .onnx import from_onnx
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
from .caffe2 import from_caffe2
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared functions and classes for frontends."""
from __future__ import absolute_import as _abs
import logging
from nnvm import sym as _sym
from .._base import string_types
def get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise OpNotImplemented(
'Operator {} is not supported.'.format(op))
return op
def required_attr(attr, key, op_name):
assert isinstance(attr, dict)
if key not in attr:
raise OpAttributeRequired(
'Required attribute {} not found in operator {}'.format(key, op_name))
return attr[key]
def parse_tshape(tshape):
"""Parse tshape in string."""
return [int(x.strip()) for x in tshape.strip('()').split(',')]
def parse_bool_str(attr, key, default='False'):
"""Parse bool string to boolean."""
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
class Renamer(object):
"""A simply renamer for operators.
Parameters
----------
new_name : str
The new name for the operator
"""
def __init__(self, new_name):
self._new_name = new_name
def __call__(self, inputs, attrs, *args):
return get_nnvm_op(self._new_name)(*inputs, **attrs)
class AttrConverter(object):
"""Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
```
Parameters
----------
op_name : str or callable
If set as str, returned operator name is the str.
If set as callable, returned operator is the str returned by calling:
`op_name = func(attr)`
transforms : dict of `new_name, or (new_name, default_value, transform function)`
If only a new_name is provided, it's like renaming the attribute name.
If default_value if provided, then the attribute is considered as optional.
If transform function is provided, the original attribute value is handled
by transform function.
excludes : list
A list of excluded attributes that should `NOT` appear.
Raise NotImplementedError if occurred.
disables : list
A list of attributes that is disabled in nnvm. Log warnings.
ignores : list
A list of attributes that is ignored in nnvm. Debug level logging.
extras : dict
A series of additional attributes should be added anyway to the returned
attribute dict.
custom_check : callable
A custom function takes attribute, and return True/False.
Raise RuntimeError if not bool(True) returned.
"""
def __init__(self, op_name, transforms=None,
excludes=None, disables=None, ignores=None,
extras=None, custom_check=None):
self._op_name = op_name
self._transforms = transforms if transforms else {}
self._excludes = excludes if excludes else []
self._disables = disables if disables else []
self._ignores = ignores if ignores else []
self._extras = extras if extras else {}
self._custom_check = custom_check
def __call__(self, inputs, attrs, *args):
# apply custom check
if self._custom_check:
func, msg = self._custom_check
if not func(attrs):
raise RuntimeError("Check failed: {}".format(msg))
# get new op_name
if isinstance(self._op_name, string_types):
op_name = self._op_name
else:
assert callable(self._op_name), "op_name can either be string or callable"
op_name = self._op_name(attrs)
# convert attributes
new_attrs = {}
for k in attrs.keys():
if k in self._excludes:
raise NotImplementedError("Attribute {} not supported yet.".format(k))
elif k in self._disables:
logging.warning("Attribute %s is disabled in nnvm.sym.%s", k, op_name)
elif k in self._ignores:
logging.debug("Attribute %s is ignored in nnvm.sym.%s", k, op_name)
elif k in self._transforms:
new_name, defaults, transform = self._parse_default(self._transforms[k])
if defaults is None:
new_attr = self._required_attr(attrs, k)
else:
new_attr = attrs.get(k, None)
if new_attr is None:
new_attrs[new_name] = defaults
else:
new_attrs[new_name] = transform(new_attr)
else:
# copy
new_attrs[k] = attrs[k]
# add extras
new_attrs.update(self._extras)
return get_nnvm_op(op_name)(*inputs, **new_attrs)
def _parse_default(self, target):
"""Helper function to parse default values."""
if not isinstance(target, (list, tuple)):
k, v, t = target, None, lambda x: x
elif len(target) == 1:
k, v, t = target[0], None, lambda x: x
elif len(target) == 2:
k, v, t = target[0], target[1], lambda x: x
elif len(target) > 2:
k, v, t = target[0], target[1], target[2]
else:
k = None # should raise
if not isinstance(k, string_types):
msg = "{} is not a valid target, (name, default) expected.".format(target)
raise ValueError(msg)
return k, v, t
def _parse_bool(self, value):
"""Helper function to parse default boolean values."""
if isinstance(value, string_types):
return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
return bool(value)
def _required_attr(self, attr, key):
"""Wrapper for getting required attributes."""
assert isinstance(attr, dict)
if key not in attr:
raise AttributeError("Required attribute {} not found.".format(key))
return attr[key]
class SymbolTable(object):
"""Table storing symbols by names."""
def __init__(self):
self.vars = {}
self.params = {}
self.const_ctr = 1
self.in_padding = False
self.paddings = [0, 0]
def new_const(self, value):
name = "_param_%d" % (self.const_ctr)
self.const_ctr += 1
self.params[name] = value
self.vars[name] = _sym.Variable(name=name)
return self.vars[name]
def get_var(self, name, must_contain=True):
if must_contain:
assert name in self.vars
if name not in self.vars:
self.vars[name] = _sym.Variable(name=name)
return self.vars[name]
def set_var(self, name, sym):
assert isinstance(sym, _sym.Symbol)
self.vars[name] = sym
def set_padding(self, paddings):
self.paddings = paddings
self.in_padding = True
def clear_padding(self):
self.in_padding = False
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Util functions shared by the ONNX and Caffe2 frontends."""
from __future__ import absolute_import as _abs
from nnvm import graph as _graph
from nnvm.compiler import graph_util
def dimension_picker(prefix, surfix=''):
def _impl(attr):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
raise NotImplementedError("Only 2d kernel supported.")
return _impl
def dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
return True
return False
return _dim_check, "Only 2d kernel supported."
def infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 don't provide
these attributes. We check the shape of weights provided to get the number.
"""
g = _graph.create(inputs)
shape_dict = {k: v.shape for k, v in params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
return channels
def revert_caffe2_pad(pads):
"""Caffe2 require two times the normal padding."""
if len(pads) == 4:
pads = pads[:2]
elif len(pads) == 2:
pass
else:
raise ValueError("Invalid caffe2 type padding: {}".format(pads))
return pads
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""NNVM Graph IR API.
This is a developer API that is used to manipulate and transform graphs.
"""
from __future__ import absolute_import as _abs
import ctypes
import json
from ._base import _LIB
from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Variable, Symbol, Group as _Group
class GraphIndex(object):
"""Index for quickly accessing graph attributes.
Parameters
----------
graph : Graph
The graph to create index.
"""
def __init__(self, graph):
jgraph = json.loads(create(graph).apply("SaveJSON").json_attr("json"))
self.nodes = jgraph["nodes"]
self.entry_ptr = jgraph["node_row_ptr"]
self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
self.input_names = graph.symbol.list_input_names()
self.output_entries = jgraph["heads"]
@property
def num_nodes(self):
"""Number of nodes in graph."""
return len(self.entry_ptr) - 1
@property
def num_node_entries(self):
"""Number of nodes in graph."""
return self.entry_ptr[-1]
def node_id(self, key):
"""Get the node index for a given key.
Parameters
----------
key : str or int
The node key or index
Returns
-------
index : int
The entry index
"""
return self._name2nodeid[key]
def entry_id(self, key, value_index=0):
"""Get the entry id of a node entry.
Parameters
----------
key : str or int
The node key or index
value_index : int
The value index of output
Returns
-------
index : int
The entry index
"""
if isinstance(key, (list, tuple)):
if len(key) != 3:
raise ValueError("Expect entry index to be tuple of 3 elems")
key, value_index, _ = key
idx = self.node_id(key) if isinstance(key, str) else key
assert value_index < self.entry_ptr[idx + 1]
return self.entry_ptr[idx] + value_index
class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
"""
_tvm_tcode = 17
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : GraphHandle
the handle to the underlying C++ Graph
"""
self.handle = handle
self._index = None
def __del__(self):
check_call(_LIB.NNGraphFree(self.handle))
def json_attr(self, key):
"""Get attribute string from the graph.
Parameters
----------
key : str
The key to get attribute from.
Returns
-------
value : str
The attribute value of the key, returns None if attribute do not exist.
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNGraphGetJSONAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
json_str = py_str(ret.value)
return json.loads(json_str)[1]
return None
def _set_symbol_list_attr(self, key, value):
"""Set the attribute of the graph.
Parameters
----------
key : string
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
if isinstance(value, list):
value = _Group(value)
if not isinstance(value, Symbol):
raise ValueError("value need to be grouped symbol")
check_call(_LIB.NNGraphSetNodeEntryListAttr_(
self.handle, c_str(key), value.handle))
def _set_json_attr(self, key, value, type_name=None):
"""Set the attribute of the graph.
Parameters
----------
key : string
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
if isinstance(value, string_types):
type_name = 'str'
elif type_name is None:
raise ValueError("Need to specify type_name")
json_value = json.dumps([type_name, value])
check_call(_LIB.NNGraphSetJSONAttr(
self.handle, c_str(key), c_str(json_value)))
@property
def _tvm_handle(self):
return self.handle.value
@property
def symbol(self):
shandle = SymbolHandle()
check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
return Symbol(shandle)
def json(self):
"""Get JSON representation of the graph
Returns
-------
json : str
JSON representation of the graph
"""
return self.apply("SaveJSON").json_attr("json")
def _tvm_graph_json(self):
"""Get TVM graph json"""
return self.json()
@property
def index(self):
if not self._index:
self._index = GraphIndex(self)
return self._index
def ir(self, join_entry_attrs=None, join_node_attrs=None):
"""Get text form of graph ir.
Parameters
----------
join_entry_attrs : list of str
List of graph NodeEntry attribute to be
printed along each operator.
join_node_attrs : list of str
List of graph node attribute to be
printed along each operator.
"""
if join_entry_attrs:
self._set_json_attr("join_entry_attrs", join_entry_attrs, "list_str")
if join_node_attrs:
self._set_json_attr("join_node_attrs", join_node_attrs, "list_str")
return self.apply("PrintGraphIR").json_attr("graphir")
def apply(self, passes):
"""Apply passes to the graph
Parameters
----------
passes : str or list of str
The passes to be applied
Returns
-------
g : Graph
The transformed graph.
"""
if isinstance(passes, string_types):
passes = [passes]
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
ghandle = GraphHandle()
npass = nn_uint(len(passes))
check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
return Graph(ghandle)
def load_json(json_str):
"""Create a new graph by loading from json
Parameters
----------
json_str : str
The json string
Returns
-------
graph : Graph
The loaded graph
"""
ret = create(Variable("x"))
ret._set_json_attr("json", json_str)
return ret.apply("LoadJSON")
def create(symbol):
"""Create a new graph from symbol.
Parameters
----------
symbol : Symbol
The symbolic graph used to create Graph object.
Returns
-------
graph : Graph
A generated new graph object.
"""
ghandle = GraphHandle()
check_call(_LIB.NNGraphCreate(
symbol.handle, ctypes.byref(ghandle)))
return Graph(ghandle)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Information about nnvm."""
from __future__ import absolute_import
import sys
import os
import platform
if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__
def find_lib_path():
"""Find NNNet dynamic library files.
Returns
-------
lib_path : list(string)
List of all found path to the libraries
"""
if hasattr(__builtin__, "NNVM_BASE_PATH"):
base_path = __builtin__.NNVM_BASE_PATH
else:
base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
lib_name = __builtin__.NNVM_LIBRARY_NAME
else:
lib_name = "nnvm_compiler" if sys.platform.startswith('win32') else "libnnvm_compiler"
api_path = os.path.join(base_path, '..', '..', 'lib')
cmake_build_path_win = os.path.join(base_path, '..', '..', '..', 'build', 'Release')
cmake_build_path = os.path.join(base_path, '..', '..', '..', 'build')
install_path = os.path.join(base_path, '..', '..', '..')
dll_path = [base_path, api_path, cmake_build_path_win, cmake_build_path,
install_path]
if sys.platform.startswith('linux') and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
elif sys.platform.startswith('darwin') and os.environ.get('DYLD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['DYLD_LIBRARY_PATH'].split(":")])
elif sys.platform.startswith('win32') and os.environ.get('PATH', None):
dll_path.extend([p.strip() for p in os.environ['PATH'].split(";")])
if sys.platform.startswith('win32'):
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(base_path, '..', '..', '..', 'build', vs_configuration))
dll_path.append(os.path.join(base_path, '..', '..', '..', 'windows', 'x64',
vs_configuration))
else:
dll_path.append(os.path.join(base_path, '..', '..', '..', 'build', vs_configuration))
dll_path.append(os.path.join(base_path, '..', '..', '..', 'windows', vs_configuration))
dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path]
elif sys.platform.startswith('darwin'):
dll_path = [os.path.join(p, '%s.dylib' % lib_name) for p in dll_path]
else:
dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_path:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path
# current version
__version__ = "0.8.0"
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import as _abs
class NameManager(object):
"""NameManager to do automatic naming.
User can also inherit this object to change naming behavior.
"""
current = None
def __init__(self):
self._counter = {}
self._old_manager = None
def get(self, name, hint):
"""Get the canonical name for a symbol.
This is default implementation.
When user specified a name,
the user specified name will be used.
When user did not, we will automatically generate a
name based on hint string.
Parameters
----------
name : str or None
The name user specified.
hint : str
A hint string, which can be used to generate name.
Returns
-------
full_name : str
A canonical name for the user.
"""
if name:
return name
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name
def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
return self
def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager
class Prefix(NameManager):
"""A name manager that always attach a prefix to all names.
Examples
--------
>>> import nnvm as nn
>>> data = nn.symbol.Variable('data')
>>> with nn.name.Prefix('mynet_'):
net = nn.symbol.FullyConnected(data, num_hidden=10, name='fc1')
>>> net.list_arguments()
['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
"""
def __init__(self, prefix):
super(Prefix, self).__init__()
self._prefix = prefix
def get(self, name, hint):
name = super(Prefix, self).get(name, hint)
return self._prefix + name
# initialize the default name manager
NameManager.current = NameManager()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs
from .config import ctx_list
from .utils import create_workload
from . import mobilenet
from . import mobilenet_v2
from . import mlp
from . import resnet
from . import vgg
from . import densenet
from . import squeezenet
from . import inception_v3
from . import dcgan
from . import dqn
from . import check_computation
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configuration about tests"""
from __future__ import absolute_import as _abs
import os
import tvm
def ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [(device, tvm.context(device, 0)) for device in device_list]
return [x for x in res if x[1].exist]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""
Symbol of the generator of DCGAN
Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py
Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
from .. import symbol as sym
from . utils import create_workload
def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
net = sym.conv2d_transpose(data,
kernel_size=kshape,
strides=stride,
channels=oshape[0],
padding=(pad_y, pad_x),
output_padding=(adj_y, adj_x),
use_bias=False,
name=name)
return net
def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = sym.batch_norm(net, epsilon=eps, name="%s_bn" % prefix)
net = sym.relu(net, name="%s_act" % prefix)
return net
def get_symbol(oshape, ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 64, "Only support 64x64 image"
assert oshape[-2] == 64, "Only support 64x64 image"
code = sym.Variable("data") if code is None else code
net = sym.dense(code, name="g1", units=4*4*ngf*8, use_bias=False)
net = sym.relu(net)
# 4 x 4
net = sym.reshape(net, shape=(-1, ngf * 8, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
# 64x64
net = deconv2d(
net, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
net = sym.tanh(net)
return net
def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype="float32"):
"""Get benchmark workload for a DCGAN generator
Parameters
----------
batch_size : int
The batch size used in the model
oshape : tuple, optional
The shape of output image, layout="CHW"
ngf: int, optional
The number of final feature maps in the generator
random_len : int, optional
The length of random input
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(oshape=oshape, ngf=ngf)
return create_workload(net, batch_size, (random_len, ), dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
DenseNet, load model from gluon model zoo
Reference:
Huang, Gao, et al. "Densely Connected Convolutional Networks." CVPR 2017
"""
from .utils import create_workload
from ..frontend.mxnet import _from_mxnet_impl
def get_workload(batch_size, num_classes=1000, num_layers=121, dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
num_layers : int, optional
The number of layers
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
image_shape = (1, 3, 224, 224)
block = get_model('densenet%d' % num_layers, classes=num_classes, pretrained=False)
data = mx.sym.Variable('data')
sym = block(data)
sym = mx.sym.SoftmaxOutput(sym)
net = _from_mxnet_impl(sym, {})
return create_workload(net, batch_size, image_shape[1:], dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Symbol of Nature DQN
Reference:
Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""
from .. import symbol as sym
from . utils import create_workload
def get_symbol(num_actions=18):
"""get symbol of nature dqn"""
data = sym.Variable(name='data')
net = sym.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name='conv1')
net = sym.relu(net, name='relu1')
net = sym.conv2d(net, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name='conv2')
net = sym.relu(net, name='relu2')
net = sym.conv2d(net, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name='conv3')
net = sym.relu(net, name='relu3')
net = sym.flatten(net, name='flatten')
net = sym.dense(net, units=512, name='fc4')
net = sym.relu(net, name='relu4')
net = sym.dense(net, units=num_actions, name='fc5')
return net
def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
"""Get benchmark workload for a Deep Q Network
Parameters
----------
batch_size : int
The batch size used in the model
num_actions : int, optional
Number of actions
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_actions=num_actions)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Initializer of parameters."""
import numpy as np
class Initializer(object):
"""The base class of an initializer."""
def __init__(self, **kwargs):
self._kwargs = kwargs
def __call__(self, desc, arr):
"""Initialize an array
Parameters
----------
desc : str
Initialization pattern descriptor.
arr : NDArray
The array to be initialized.
"""
if desc.endswith('weight'):
self._init_weight(desc, arr)
elif desc.endswith('bias'):
self._init_bias(desc, arr)
elif desc.endswith('gamma'):
self._init_gamma(desc, arr)
elif desc.endswith('beta'):
self._init_beta(desc, arr)
elif desc.endswith('mean'):
self._init_mean(desc, arr)
elif desc.endswith('var'):
self._init_var(desc, arr)
else:
self._init_default(desc, arr)
def _init_bias(self, _, arr):
arr[:] = 0.0
def _init_gamma(self, _, arr):
arr[:] = 1.0
def _init_beta(self, _, arr):
arr[:] = 0.0
def _init_mean(self, _, arr):
arr[:] = 0.0
def _init_var(self, _, arr):
arr[:] = 1.0
def _init_weight(self, name, arr):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")
def _init_default(self, name, _):
raise ValueError(
'Unknown initialization pattern for %s. ' \
'Default initialization is now limited to '\
'"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \
'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name)
class Xavier(Initializer):
""" "Xavier" initialization for weights
Parameters
----------
rnd_type: str, optional
Random generator type, can be ``'gaussian'`` or ``'uniform'``.
factor_type: str, optional
Can be ``'avg'``, ``'in'``, or ``'out'``.
magnitude: float, optional
Scale of random number.
"""
def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
super(Xavier, self).__init__(rnd_type=rnd_type,
factor_type=factor_type,
magnitude=magnitude)
self.rnd_type = rnd_type
self.factor_type = factor_type
self.magnitude = float(magnitude)
def _init_weight(self, name, arr):
shape = arr.shape
hw_scale = 1.
if len(shape) < 2:
raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
' least 2D.'.format(name))
if len(shape) > 2:
hw_scale = np.prod(shape[2:])
fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
factor = 1.
if self.factor_type == "avg":
factor = (fan_in + fan_out) / 2.0
elif self.factor_type == "in":
factor = fan_in
elif self.factor_type == "out":
factor = fan_out
else:
raise ValueError("Incorrect factor type")
# Hack for mobilenet, because there is less connectivity
if "depthwise" in name:
factor = 3 * 3
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == "uniform":
arr[:] = np.random.uniform(-scale, scale, size=arr.shape)
else:
raise ValueError("Unknown random type")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
a simple multilayer perceptron
"""
from .. import symbol as sym
from . utils import create_workload
def get_symbol(num_classes=1000):
data = sym.Variable('data')
data = sym.flatten(data=data)
fc1 = sym.dense(data=data, name='fc1', units=128)
act1 = sym.relu(data=fc1, name='relu1')
fc2 = sym.dense(data=act1, name='fc2', units=64)
act2 = sym.relu(data=fc2, name='relu2')
fc3 = sym.dense(data=act2, name='fc3', units=num_classes)
mlp = sym.softmax(data=fc3, name='softmax')
return mlp
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for a simple multilayer perceptron
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Helper utility to get mobilenet workload for testing."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from .. import symbol as sym
from . utils import create_workload
def conv_block(data, name, channels,
kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
epsilon=1e-5):
"""Helper function to construct conv-bn-relu"""
# convolution + bn + relu
conv = sym.conv2d(data=data, channels=channels,
kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False,
layout="NCHW", name=name + "_conv")
bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn")
act = sym.relu(data=bn, name=name + "_relu")
return act
def separable_conv_block(data, name, depthwise_channels,
pointwise_channels, kernel_size=(3, 3),
downsample=False, padding=(1, 1),
epsilon=1e-5):
"""Helper function to get a separable conv block"""
if downsample:
strides = (2, 2)
else:
strides = (1, 1)
# depthwise convolution + bn + relu
conv1 = sym.conv2d(data=data, channels=depthwise_channels,
groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
padding=padding, use_bias=False, layout="NCHW",
name=name + "_depthwise_conv1")
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
act1 = sym.relu(data=bn1, name=name + "_relu1")
# pointwise convolution + bn + relu
conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1),
padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2")
bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2")
act2 = sym.relu(data=bn2, name=name + "_relu2")
return act2
def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
"""Function to construct a MobileNet"""
data = sym.Variable("data")
body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2))
body = separable_conv_block(body, "separable_conv_block_1",
int(32*alpha), int(64*alpha))
body = separable_conv_block(body, "separable_conv_block_2",
int(64*alpha), int(128*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_3",
int(128*alpha), int(128*alpha))
body = separable_conv_block(body, "separable_conv_block_4",
int(128*alpha), int(256*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_5",
int(256*alpha), int(256*alpha))
body = separable_conv_block(body, "separable_conv_block_6",
int(256*alpha), int(512*alpha), downsample=True)
if is_shallow:
body = separable_conv_block(body, "separable_conv_block_7",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_8",
int(1024*alpha), int(1024*alpha))
else:
for i in range(7, 12):
body = separable_conv_block(body, "separable_conv_block_%d" % i,
int(512*alpha), int(512*alpha))
body = separable_conv_block(body, "separable_conv_block_12",
int(512*alpha), int(1024*alpha), downsample=True)
body = separable_conv_block(body, "separable_conv_block_13",
int(1024*alpha), int(1024*alpha))
pool = sym.global_avg_pool2d(data=body, name="pool")
flatten = sym.flatten(data=pool, name="flatten")
fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc")
softmax = sym.softmax(data=fc, name="softmax")
return softmax
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MobileNetV2, load model from gluon model zoo
Reference:
Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation
https://arxiv.org/abs/1801.04381
"""
from .utils import create_workload
from ..frontend.mxnet import _from_mxnet_impl
def get_workload(batch_size, num_classes=1000, multiplier=1.0, dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
multiplier : tuple, optional
The input image shape
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
import mxnet as mx
from mxnet.gluon.model_zoo.vision.mobilenet import MobileNetV2
image_shape = (1, 3, 224, 224)
block = MobileNetV2(multiplier=multiplier, classes=num_classes)
data = mx.sym.Variable('data')
sym = block(data)
sym = mx.sym.SoftmaxOutput(sym)
net = _from_mxnet_impl(sym, {})
return create_workload(net, batch_size, image_shape[1:], dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
'''
Adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
Original author Wei Wu
Implemented the following paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
# pylint: disable=unused-argument
from .. import symbol as sym
from . utils import create_workload
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same,
otherwise means differ
name : str
Base name of the operators
"""
if bottle_neck:
bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1')
conv1 = sym.conv2d(
data=act1, channels=int(num_filter*0.25), kernel_size=(1, 1),
strides=stride, padding=(0, 0), use_bias=False, name=name + '_conv1')
bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2')
conv2 = sym.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), use_bias=False, name=name + '_conv2')
bn3 = sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = sym.relu(data=bn3, name=name + '_relu3')
conv3 = sym.conv2d(
data=act3, channels=num_filter, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv3')
if dim_match:
shortcut = data
else:
shortcut = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, use_bias=False, name=name+'_sc')
return sym.elemwise_add(conv3, shortcut)
else:
bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1')
conv1 = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(3, 3),
strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv1')
bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2')
conv2 = sym.conv2d(
data=act2, channels=num_filter, kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), use_bias=False, name=name + '_conv2')
if dim_match:
shortcut = data
else:
shortcut = sym.conv2d(
data=act1, channels=num_filter, kernel_size=(1, 1),
strides=stride, use_bias=False, name=name+'_sc')
return sym.elemwise_add(conv2, shortcut)
def resnet(units, num_stages, filter_list, num_classes, image_shape,
bottle_neck=True):
"""Return ResNet symbol of
Parameters
----------
units : list
Number of units in each stage
num_stages : int
Number of stage
filter_list : list
Channel size of each stage
num_classes : int
Ouput size of symbol
dataset : str
Dataset type, only cifar10 and imagenet supports
"""
num_unit = len(units)
assert num_unit == num_stages
data = sym.Variable(name='data')
data = sym.batch_norm(data=data, epsilon=2e-5, scale=False, name='bn_data')
(_, height, _) = image_shape
if height <= 32: # such as cifar10
body = sym.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
strides=(1, 1), padding=(1, 1), use_bias=False, name="conv0")
else: # often expected to be 224 such as imagenet
body = sym.conv2d(
data=data, channels=filter_list[0], kernel_size=(7, 7),
strides=(2, 2), padding=(3, 3), use_bias=False, name="conv0")
body = sym.batch_norm(data=body, epsilon=2e-5, name='bn0')
body = sym.relu(data=body, name='relu0')
body = sym.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
for i in range(num_stages):
body = residual_unit(
body, filter_list[i+1], (1 if i == 0 else 2, 1 if i == 0 else 2),
False, name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck)
for j in range(units[i]-1):
body = residual_unit(
body, filter_list[i+1], (1, 1), True,
name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck)
bn1 = sym.batch_norm(data=body, epsilon=2e-5, name='bn1')
relu1 = sym.relu(data=bn1, name='relu1')
# Although kernel is not used here when global_pool=True, we should put one
pool1 = sym.global_avg_pool2d(data=relu1, name='pool1')
flat = sym.flatten(data=pool1)
fc1 = sym.dense(data=flat, units=num_classes, name='fc1')
return sym.softmax(data=fc1, name='softmax')
def get_symbol(num_classes, num_layers=50, image_shape=(3, 224, 224), **kwargs):
"""
Adapted from https://github.com/tornadomeet/ResNet/blob/master/train_resnet.py
Original author Wei Wu
"""
(_, height, _) = image_shape
if height <= 28:
num_stages = 3
if (num_layers-2) % 9 == 0 and num_layers >= 164:
per_unit = [(num_layers-2)//9]
filter_list = [16, 64, 128, 256]
bottle_neck = True
elif (num_layers-2) % 6 == 0 and num_layers < 164:
per_unit = [(num_layers-2)//6]
filter_list = [16, 16, 32, 64]
bottle_neck = False
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
units = per_unit * num_stages
else:
if num_layers >= 50:
filter_list = [64, 256, 512, 1024, 2048]
bottle_neck = True
else:
filter_list = [64, 64, 128, 256, 512]
bottle_neck = False
num_stages = 4
if num_layers == 18:
units = [2, 2, 2, 2]
elif num_layers == 34:
units = [3, 4, 6, 3]
elif num_layers == 50:
units = [3, 4, 6, 3]
elif num_layers == 101:
units = [3, 4, 23, 3]
elif num_layers == 152:
units = [3, 8, 36, 3]
elif num_layers == 200:
units = [3, 24, 36, 3]
elif num_layers == 269:
units = [3, 30, 48, 8]
else:
raise ValueError("no experiments done on num_layers {}".format(num_layers))
return resnet(units=units,
num_stages=num_stages,
filter_list=filter_list,
num_classes=num_classes,
image_shape=image_shape,
bottle_neck=bottle_neck)
def get_workload(batch_size=1, num_classes=1000, num_layers=18,
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for resnet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
num_layers : int, optional
Number of layers
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, num_layers=num_layers,
image_shape=image_shape, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=unused-argument
"""
Symbol of SqueezeNet
Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""
from .. import symbol as sym
from . utils import create_workload
# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)
left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = sym.concatenate(left, right, axis=1)
return net
def _make_fire_conv(net, channels, kernel_size, padding=0):
net = sym.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = sym.relu(net)
return net
# Net
def get_symbol(num_classes, version, **kwargs):
"""Get symbol of SqueezeNet
Parameters
----------
num_classes: int
The number of classification results
version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}:"
"1.0 or 1.1 expected".format(version=version))
net = sym.Variable("data")
if version == '1.0':
net = sym.conv2d(net, channels=96, kernel_size=(7, 7), strides=(2, 2), padding=(3, 3))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = sym.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = sym.dropout(net, rate=0.5)
net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
net = sym.relu(net)
net = sym.global_avg_pool2d(net)
net = sym.flatten(net)
return sym.softmax(net)
def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for SqueezeNet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
version : str, optional
"1.0" or "1.1" of SqueezeNet
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, version=version, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Helper utility to create common workload for testing."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from ..compiler import graph_util
from ..import graph
from . init import Xavier
def create_workload(net, batch_size, image_shape=(3, 224, 224),
dtype="float32", initializer=None, seed=0):
"""Helper function to create benchmark workload for input network
Parameters
----------
net : nnvm.Symbol
The selected network symbol to use
batch_size : int
The batch size used in the model
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
initializer : Initializer
The initializer used
seed : int
The seed used in initialization.
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
if image_shape is None:
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
params = {}
g = graph.create(net)
input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
shape_dict = dict(zip(g.index.input_names, input_shapes))
np.random.seed(seed)
initializer = initializer if initializer else Xavier()
for k, v in shape_dict.items():
if k == "data":
continue
init_value = np.zeros(v).astype(dtype)
initializer(k, init_value)
params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
return net, params
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""References:
Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
"""
from .. import symbol as sym
from . utils import create_workload
def get_feature(internel_layer, layers, filters, batch_norm=False):
"""Get VGG feature body as stacks of convoltions."""
for i, num in enumerate(layers):
for j in range(num):
internel_layer = sym.conv2d(
data=internel_layer, kernel_size=(3, 3), padding=(1, 1),
channels=filters[i], name="conv%s_%s"%(i + 1, j + 1))
if batch_norm:
internel_layer = sym.batch_norm(
data=internel_layer, name="bn%s_%s" %(i + 1, j + 1))
internel_layer = sym.relu(data=internel_layer, name="relu%s_%s" %(i + 1, j + 1))
internel_layer = sym.max_pool2d(
data=internel_layer, pool_size=(2, 2), strides=(2, 2), name="pool%s"%(i + 1))
return internel_layer
def get_classifier(input_data, num_classes):
"""Get VGG classifier layers as fc layers."""
flatten = sym.flatten(data=input_data, name="flatten")
fc6 = sym.dense(data=flatten, units=4096, name="fc6")
relu6 = sym.relu(data=fc6, name="relu6")
drop6 = sym.dropout(data=relu6, rate=0.5, name="drop6")
fc7 = sym.dense(data=drop6, units=4096, name="fc7")
relu7 = sym.relu(data=fc7, name="relu7")
drop7 = sym.dropout(data=relu7, rate=0.5, name="drop7")
fc8 = sym.dense(data=drop7, units=num_classes, name="fc8")
return fc8
def get_symbol(num_classes, num_layers=11, batch_norm=False):
"""
Parameters
----------
num_classes : int, default 1000
Number of classification classes.
num_layers : int
Number of layers for the variant of densenet. Options are 11, 13, 16, 19.
batch_norm : bool, default False
Use batch normalization.
"""
vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
if num_layers not in vgg_spec:
raise ValueError("Invalide num_layers {}. Choices are 11,13,16,19.".format(num_layers))
layers, filters = vgg_spec[num_layers]
data = sym.Variable(name="data")
feature = get_feature(data, layers, filters, batch_norm)
classifier = get_classifier(feature, num_classes)
symbol = sym.softmax(data=classifier, name='softmax')
return symbol
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224),
dtype="float32", **kwargs):
"""Get benchmark workload for VGG nets.
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of claseses
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type
kwargs : dict
Extra arguments
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor operator property registry
Provide information to lower and schedule tensor operators.
"""
from .attr_dict import AttrDict
from . import tensor
from . import nn
from . import transform
from . import reduction
from . import vision
from . import image
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Attr dictionary object used by schedule functions"""
import tvm
_dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
_dict_size = tvm.get_global_func("nnvm.compiler._dict_size")
_dict_keys = tvm.get_global_func("nnvm.compiler._dict_keys")
class AttrDict(object):
"""Attribute dictionary in nnvm.
Used by python registration of compute and schedule function.
AttrDict is passed as the first argument to schedule and compute function.
"""
_tvm_tcode = 18
def __init__(self, handle):
self.handle = handle
def __del__(self):
tvm.nd.free_extension_handle(self.handle, 18)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, key):
return _dict_get(self, key)
def keys(self):
"""Get list of keys in the dict.
Returns
-------
keys : list of str
List of keys
"""
return [x.value for x in _dict_keys(self)]
def get_int_tuple(self, key):
"""Get tuple of integer from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of int
The result tuple
"""
return tuple(int(x) for x in self[key][1:-1].split(",") if x)
def get_int_pair_tuple(self, key):
"""Get tuple of integer pairs from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of int pairs
The result tuple
"""
flat = [int(x.strip(' [] ')) for x in self[key][1:-1].split(",")]
return tuple((flat[i], flat[i+1]) for i in range(0, len(flat), 2))
def get_int(self, key):
"""Get integer from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : int
The result value
"""
return int(self[key])
def get_float_tuple(self, key):
"""Get tuple of float from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of float
The result tuple
"""
return tuple(float(x) for x in self[key][1:-1].split(",") if x)
def get_float(self, key):
"""Get float from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : float
The result value
"""
return float(self[key])
def get_bool(self, key):
"""Get bool from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : bool
The result value
"""
lowercase = self[key].lower()
if lowercase == "1":
return True
if lowercase == "0":
return False
if lowercase == "true":
return True
if lowercase == "false":
return False
raise ValueError("Wrong bool format for key %s" % key)
def get_str(self, key):
"""Get string from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : str
The result value
"""
return self[key]
def __repr__(self):
return str({k : self[k] for k in self.keys()})
tvm.register_extension(AttrDict, AttrDict)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Definition of image ops"""
from __future__ import absolute_import
import tvm
import topi
from . import registry as reg
from .registry import OpPattern
# resize
@reg.register_schedule("resize")
def schedule_resize(_, outs, target):
"""Schedule definition of resize"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)
reg.register_pattern("resize", OpPattern.INJECTIVE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Reduction ops"""
from __future__ import absolute_import
import tvm
import topi
import topi.cuda
from . import registry as reg
from .registry import OpPattern
def _schedule_reduce(_, outs, target):
"""Generic schedule for reduce"""
with tvm.target.create(target):
return topi.generic.schedule_reduce(outs)
_fschedule_reduce = tvm.convert(_schedule_reduce)
def _compute_reduce(f):
"""auxiliary function"""
def _compute(attrs, inputs, out_info):
axis = attrs.get_int_tuple("axis")
keepdims = attrs.get_bool("keepdims")
if axis:
return f(inputs[0], axis=axis, keepdims=keepdims)
return f(inputs[0], keepdims=keepdims)
return _compute
# sum
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)
# max
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)
# min
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)
# collapse sum
reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE)
reg.register_schedule("collapse_sum", _fschedule_reduce)
# argmax
reg.register_pattern("argmax", OpPattern.COMM_REDUCE)
reg.register_schedule("argmax", _fschedule_reduce)
# argmin
reg.register_pattern("argmin", OpPattern.COMM_REDUCE)
reg.register_schedule("argmin", _fschedule_reduce)
# mean
reg.register_pattern("mean", OpPattern.COMM_REDUCE)
reg.register_schedule("mean", _fschedule_reduce)
# product
reg.register_pattern("prod", OpPattern.COMM_REDUCE)
reg.register_schedule("prod", _fschedule_reduce)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Information registry to register operator information for compiler"""
import tvm
class OpPattern(object):
"""Operator generic patterns
See Also
--------
top.tag : Contains explanation of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1
# Injective mapping
INJECTIVE = 2
# Comunication
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8
_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
_register_pattern = tvm.get_global_func("nnvm._register_pattern")
_register_alter_op_layout = tvm.get_global_func("nnvm.compiler._register_alter_op_layout")
def register_compute(op_name, f=None, level=10):
"""Register compute function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_compute(op_name, myf, level)
return myf
return register(f) if f else register
def register_schedule(op_name, f=None, level=10):
"""Register schedule function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_schedule(op_name, myf, level)
return myf
return register(f) if f else register
def register_pattern(op_name, pattern, level=10):
"""Register pattern code for operator
Parameters
----------
op_name : str
The name of operator
pattern : int
The pattern code.
level : int
The priority level
"""
_register_pattern(op_name, pattern, level)
def register_alter_op_layout(op_name, f=None, level=10):
"""Register alter layout function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_alter_op_layout(op_name, myf, level)
return myf
return register(f) if f else register
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment