Commit 9f8fcfc9 by Yizhi Liu Committed by Tianqi Chen

General Layout Support (#447)

parent fc7e9cd2
/*!
* Copyright (c) 2017 by Contributors
* \file contrib_op_param.h
* \brief Additional parameters for compiler optimized operators.
*/
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#include <dmlc/parameter.h>
#include <string>
namespace nnvm {
namespace compiler {
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include "packed_func_ext.h"
namespace nnvm { namespace nnvm {
namespace compiler { namespace compiler {
...@@ -73,19 +74,17 @@ using FTVMSchedule = std::function< ...@@ -73,19 +74,17 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs, const Array<Tensor>& outs,
const std::string& target)>; const std::string& target)>;
/*! \brief Layout Information about an entry */
using TLayoutInfo = std::string;
/*! /*!
* \brief The producer consumer function of node layout * \brief Modify the op node to alter its input layout.
* \param attrs The attribute of the node. * it is invoked in AlterOpLayout pass.
* \param ilayouts The input layouts that the node request. * \param attrs The attribute of the original node.
* \param olayouts The output layouts that the node produce. * \param inputs The input symbols of the original node.
* \return bool The success flag. * \param tinfos The inferred shape and dtype of the inputs.
*/ */
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs, using FTVMAlterOpLayout = std::function<
std::vector<TLayoutInfo> *ilayouts, Symbol(const NodeAttrs& attrs,
std::vector<TLayoutInfo> *olayouts)>; const Symbol& inputs,
const Array<Tensor>& tinfos)>;
/*! /*!
* \brief Transform from normal operator to vectorized operator * \brief Transform from normal operator to vectorized operator
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <nnvm/symbolic.h> #include <nnvm/symbolic.h>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
namespace nnvm { namespace nnvm {
...@@ -52,6 +53,7 @@ template<> ...@@ -52,6 +53,7 @@ template<>
struct extension_class_info<nnvm::compiler::AttrDict> { struct extension_class_info<nnvm::compiler::AttrDict> {
static const int code = 18; static const int code = 18;
}; };
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_ #endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "./tuple.h" #include "./tuple.h"
#include "./layout.h"
namespace nnvm { namespace nnvm {
...@@ -46,7 +47,7 @@ using ShapeVector = std::vector<TShape>; ...@@ -46,7 +47,7 @@ using ShapeVector = std::vector<TShape>;
* \code * \code
* Graph g = ApplyPass(src_graph, "InferType"); * Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype"); * const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id * // get type by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode * \endcode
* *
...@@ -55,6 +56,21 @@ using ShapeVector = std::vector<TShape>; ...@@ -55,6 +56,21 @@ using ShapeVector = std::vector<TShape>;
using DTypeVector = std::vector<int>; using DTypeVector = std::vector<int>;
/*! /*!
* \brief The result holder of layout of each NodeEntry in the graph.
* \note Stored under graph.attrs["layout"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, "LayoutTransform");
* const LayoutVector& layouts = g.GetAttr<LayoutVector>("layout");
* // get layout by entry id
* int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferLayout
*/
using LayoutVector = std::vector<Layout>;
/*!
* \brief The result holder of device of each operator in the graph. * \brief The result holder of device of each operator in the graph.
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
* *
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "./base.h" #include "./base.h"
#include "./node.h" #include "./node.h"
#include "./tuple.h" #include "./tuple.h"
#include "./layout.h"
namespace nnvm { namespace nnvm {
...@@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function<void( ...@@ -176,6 +177,31 @@ using FSetInputVarAttrOnCompose = std::function<void(
NodePtr var, NodePtr var,
const int index)>; const int index)>;
/*!
* \brief Inference function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
* \param ilayouts Given the input layouts produced by ancestor nodes,
* it should be filled by layouts that the node requests.
* If the requested layout is different from what ancestor produces,
* a __layout_transform__ operator will be inserted automatically.
* \param last_ilayouts The input layouts requested by the node
* at the last infer pass (if any).
* This can be useful when an operator wants to keep
* the input layout the same as the original one.
* For example, after the pass of AlterOpLayout,
* transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout,
* with which it cannot calculate with axis=[1, 2, 3, 0].
* Last input layouts allow it to know what the layout it originally inferred,
* i.e., the layout in the imported model.
* \param olayouts Inferred output layouts.
* \return success flag.
*/
using FInferLayout = std::function<bool(
const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;
} // namespace nnvm } // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_ #endif // NNVM_OP_ATTR_TYPES_H_
...@@ -9,23 +9,12 @@ ...@@ -9,23 +9,12 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <nnvm/tuple.h> #include <nnvm/tuple.h>
#include <nnvm/layout.h>
#include <string>
namespace nnvm { namespace nnvm {
namespace top { namespace top {
// Layout flag in spatial conv and pooling.
enum LayoutFlag {
kNCHW,
kNHWC,
kCHWN,
kNCW,
kNWC,
kCWN,
kNCDHW,
kNDHWC,
kCDHWN
};
struct DenseParam : public dmlc::Parameter<DenseParam> { struct DenseParam : public dmlc::Parameter<DenseParam> {
int units; int units;
bool use_bias; bool use_bias;
...@@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> { ...@@ -130,7 +119,9 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
TShape padding; TShape padding;
TShape dilation; TShape dilation;
int groups; int groups;
int layout; std::string layout;
std::string kernel_layout;
std::string out_layout;
bool use_bias; bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DParam) { DMLC_DECLARE_PARAMETER(Conv2DParam) {
...@@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> { ...@@ -152,14 +143,19 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
"At groups=2, the operation becomes equivalent to having two convolution" "At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing" "layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated."); "half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW) .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
DMLC_DECLARE_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true) DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector."); .describe("Whether the layer uses a bias vector.");
} }
...@@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> { ...@@ -178,7 +174,8 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
TShape output_padding; TShape output_padding;
TShape dilation; TShape dilation;
int groups; int groups;
int layout; std::string layout;
std::string kernel_layout;
bool use_bias; bool use_bias;
DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) { DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) {
...@@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> { ...@@ -202,14 +199,15 @@ struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
"At groups=2, the operation becomes equivalent to having two convolution" "At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing" "layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated."); "half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW) .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions."); "'W' dimensions.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_FIELD(use_bias).set_default(true) DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector."); .describe("Whether the layer uses a bias vector.");
} }
...@@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -224,7 +222,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
TShape pool_size; TShape pool_size;
TShape strides; TShape strides;
TShape padding; TShape padding;
int layout; std::string layout;
bool ceil_mode; bool ceil_mode;
DMLC_DECLARE_PARAMETER(Pool2DParam) { DMLC_DECLARE_PARAMETER(Pool2DParam) {
...@@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -235,10 +233,7 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded" .describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points"); "on both sides for padding number of points");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> { ...@@ -250,13 +245,10 @@ struct Pool2DParam : public dmlc::Parameter<Pool2DParam> {
struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> { struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
int layout; std::string layout;
DMLC_DECLARE_PARAMETER(GlobalPool2DParam) { DMLC_DECLARE_PARAMETER(GlobalPool2DParam) {
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.add_enum("NCHW", kNCHW)
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> { ...@@ -266,15 +258,13 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> { struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale; int scale;
int layout; std::string layout;
DMLC_DECLARE_PARAMETER(UpSamplingParam) { DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale) DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor"); .describe("upsampling scaling factor");
DMLC_DECLARE_FIELD(layout) DMLC_DECLARE_FIELD(layout)
.add_enum("NCHW", kNCHW) .set_default("NCHW")
.add_enum("NHWC", kNHWC)
.set_default(kNCHW)
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and" "dimensions respectively. Convolution is applied on the 'H' and"
...@@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> { ...@@ -282,6 +272,18 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
} }
}; };
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout).set_default("__undef__")
.describe("Dimension ordering of data");
DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__")
.describe("Dimension ordering of data.");
}
};
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
......
...@@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace): ...@@ -211,12 +211,15 @@ def _init_symbol_module(symbol_class, root_namespace):
op_names.append(py_str(plist[i])) op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.symbol" % root_namespace] module_obj = sys.modules["%s.symbol" % root_namespace]
module_obj_contrib = sys.modules["%s.contrib" % root_namespace]
module_internal = sys.modules["%s._symbol_internal" % root_namespace] module_internal = sys.modules["%s._symbol_internal" % root_namespace]
for name in op_names: for name in op_names:
hdl = OpHandle() hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
function = _make_atomic_symbol_function(hdl, name) function = _make_atomic_symbol_function(hdl, name)
if function.__name__.startswith('_'): if function.__name__.startswith('_contrib_'):
setattr(module_obj_contrib, function.__name__.split('_contrib_')[1], function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function) setattr(module_internal, function.__name__, function)
setattr(module_obj, function.__name__, function) setattr(module_obj, function.__name__, function)
else: else:
......
...@@ -15,7 +15,8 @@ OPT_PASS_LEVEL = { ...@@ -15,7 +15,8 @@ OPT_PASS_LEVEL = {
"SimplifyInference": 0, "SimplifyInference": 0,
"PrecomputePrune": 2, "PrecomputePrune": 2,
"OpFusion": 1, "OpFusion": 1,
"FoldScaleAxis": 3 "FoldScaleAxis": 3,
"AlterOpLayout": 3,
} }
# List of optimization pass and level when switch on # List of optimization pass and level when switch on
...@@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params): ...@@ -139,7 +140,7 @@ def _update_shape_dtype(shape, dtype, params):
return shape, dtype return shape, dtype
def optimize(graph, shape, dtype="float32"): def optimize(graph, shape, dtype="float32", layout=None):
"""Perform target and parameter invariant graph optimization. """Perform target and parameter invariant graph optimization.
This is an advanced function that usually do not need to be called. This is an advanced function that usually do not need to be called.
...@@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"): ...@@ -157,6 +158,18 @@ def optimize(graph, shape, dtype="float32"):
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
cfg = BuildConfig.current cfg = BuildConfig.current
if cfg.pass_enabled("AlterOpLayout"):
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph = graph.apply(["InferShape", "InferType", "AlterOpLayout"])
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply(["CorrectLayout"])
if cfg.pass_enabled("SimplifyInference"): if cfg.pass_enabled("SimplifyInference"):
graph = graph_attr.set_shape_inputs(graph, shape) graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyInference"]) graph = graph.apply(["InferShape", "SimplifyInference"])
...@@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"): ...@@ -167,7 +180,8 @@ def optimize(graph, shape, dtype="float32"):
return graph return graph
def build(graph, target=None, shape=None, dtype="float32", params=None, target_host=None): def build(graph, target=None, shape=None, dtype="float32",
params=None, target_host=None, layout=None):
"""Build graph into runtime library. """Build graph into runtime library.
The build function will optimize the graph and do the compilation. The build function will optimize the graph and do the compilation.
...@@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -204,8 +218,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
By default, llvm is used if it is enabled, By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used. otherwise a stackvm intepreter is used.
initialize : bool, optional layout : dict of str to str or str optional
Whether to initialize variables in global dict _all_var_init. The input layout
Returns Returns
------- -------
...@@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -230,6 +244,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
cfg = BuildConfig.current cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph) graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
# correct layout if necessary
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
layout = {x : layouts[index.entry_id(x)] for x in index.input_names}
# Initial pass do shape type inference # Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape) ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape)) shape.update(zip(graph.index.input_names, ishape))
...@@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h ...@@ -241,13 +264,14 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
if _all_var_init: if _all_var_init:
init_var = initialize_variables(shape, dtype) init_var = initialize_variables(shape, dtype)
# Apply optimization # Apply optimization
graph = optimize(graph, shape, dtype) graph = optimize(graph, shape, dtype, layout)
# Precompute prune # Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"): if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params) graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params) shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation # Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape) graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
graph = graph_attr.set_dtype_inputs(graph, dtype) graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", str(target), "str") graph._set_json_attr("target", str(target), "str")
if target_host is not None: if target_host is not None:
......
...@@ -96,11 +96,22 @@ def set_layout_inputs(g, layout): ...@@ -96,11 +96,22 @@ def set_layout_inputs(g, layout):
Returns Returns
------- -------
g : Graph g : Graph
The updated graph with updated dtype. The updated graph with updated layout.
""" """
list_shape = [ if isinstance(layout, dict):
layout.get(name, "default") for name in g.index.input_names] list_layout = [
g._set_json_attr("layout_inputs", list_shape, 'list_str') layout.get(name, "__undef__") for name in g.index.input_names]
elif isinstance(layout, str):
list_layout = ["__undef__"] * len(g.index.input_names)
list_layout[0] = layout
else:
raise ValueError("Input layout must be str or dict")
last_inferred_layouts = g.json_attr("layout")
if last_inferred_layouts:
input_layout = [last_inferred_layouts[g.index.entry_id(x)] for x in g.index.input_names]
for i, layout_stored in enumerate(input_layout):
list_layout[i] = list_layout[i] if list_layout[i] != '__undef__' else layout_stored
g._set_json_attr("layout_inputs", list_layout, 'list_layout')
return g return g
_move_out_module = tvm.get_global_func("nnvm.graph._move_module") _move_out_module = tvm.get_global_func("nnvm.graph._move_module")
......
"""Module space to register contrib functions. Leave empty"""
...@@ -86,6 +86,10 @@ def _conv2d(inputs, attrs): ...@@ -86,6 +86,10 @@ def _conv2d(inputs, attrs):
layout = attrs.get('layout', 'NCHW') layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']: if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d') _raise_not_supported('layout: ' + layout, 'conv2d')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d', {} op_name, new_attrs = 'conv2d', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel new_attrs['kernel_size'] = kernel
...@@ -94,6 +98,7 @@ def _conv2d(inputs, attrs): ...@@ -94,6 +98,7 @@ def _conv2d(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False' new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
...@@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs): ...@@ -106,6 +111,10 @@ def _conv2d_transpose(inputs, attrs):
layout = attrs.get('layout', 'NCHW') layout = attrs.get('layout', 'NCHW')
if layout not in ['NCHW', 'NHWC']: if layout not in ['NCHW', 'NHWC']:
_raise_not_supported('layout: ' + layout, 'conv2d_transpose') _raise_not_supported('layout: ' + layout, 'conv2d_transpose')
if 'kernel_layout' in attrs:
kernel_layout = attrs['kernel_layout']
else:
kernel_layout = 'HWIO' if layout == 'NHWC' else 'OIHW'
op_name, new_attrs = 'conv2d_transpose', {} op_name, new_attrs = 'conv2d_transpose', {}
new_attrs['channels'] = _required_attr(attrs, 'num_filter') new_attrs['channels'] = _required_attr(attrs, 'num_filter')
new_attrs['kernel_size'] = kernel new_attrs['kernel_size'] = kernel
...@@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs): ...@@ -115,6 +124,7 @@ def _conv2d_transpose(inputs, attrs):
new_attrs['dilation'] = attrs.get('dilate', (1, 1)) new_attrs['dilation'] = attrs.get('dilate', (1, 1))
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['kernel_layout'] = kernel_layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
...@@ -237,7 +247,7 @@ _convert_map = { ...@@ -237,7 +247,7 @@ _convert_map = {
'min_axis' : _rename('min'), 'min_axis' : _rename('min'),
'reshape' : _reshape, 'reshape' : _reshape,
'sum_axis' : _rename('sum'), 'sum_axis' : _rename('sum'),
'UpSampling' : _upsampling 'UpSampling' : _upsampling,
} }
def _convert_symbol(op_name, inputs, attrs, def _convert_symbol(op_name, inputs, attrs,
......
...@@ -16,6 +16,7 @@ from . import _base ...@@ -16,6 +16,7 @@ from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope from .attribute import AttrScope
from . import _symbol_internal as _internal from . import _symbol_internal as _internal
from . import contrib
# Use different verison of SymbolBase # Use different verison of SymbolBase
# When possible, use cython to speedup part of computation. # When possible, use cython to speedup part of computation.
......
...@@ -5,7 +5,7 @@ from __future__ import absolute_import ...@@ -5,7 +5,7 @@ from __future__ import absolute_import
import tvm import tvm
import topi import topi
from topi.util import get_const_int from topi.util import get_const_int
from .tensor import _fschedule_broadcast from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg from . import registry as reg
from .registry import OpPattern from .registry import OpPattern
...@@ -32,6 +32,11 @@ reg.register_schedule("pad", _fschedule_broadcast) ...@@ -32,6 +32,11 @@ reg.register_schedule("pad", _fschedule_broadcast)
reg.register_pattern("pad", OpPattern.INJECTIVE) reg.register_pattern("pad", OpPattern.INJECTIVE)
# layout transform
reg.register_schedule("__layout_transform__", _fschedule_injective)
reg.register_pattern("__layout_transform__", OpPattern.INJECTIVE)
@reg.register_schedule("softmax") @reg.register_schedule("softmax")
def schedule_softmax(_, outs, target): def schedule_softmax(_, outs, target):
"""Schedule definition of softmax""" """Schedule definition of softmax"""
...@@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target): ...@@ -108,6 +113,42 @@ def schedule_conv2d(attrs, outs, target):
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# convolution NCHWc
@reg.register_compute("_contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), strides, padding)
else:
raise ValueError("not support arbitrary group number > 1 for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias)
return out
@reg.register_schedule("_contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, outs)
else:
raise ValueError("not support group number > 1 for now")
reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d_transpose # conv2d_transpose
@reg.register_compute("conv2d_transpose") @reg.register_compute("conv2d_transpose")
......
...@@ -25,6 +25,7 @@ class OpPattern(object): ...@@ -25,6 +25,7 @@ class OpPattern(object):
_register_compute = tvm.get_global_func("nnvm._register_compute") _register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule") _register_schedule = tvm.get_global_func("nnvm._register_schedule")
_register_pattern = tvm.get_global_func("nnvm._register_pattern") _register_pattern = tvm.get_global_func("nnvm._register_pattern")
_register_alter_op_layout = tvm.get_global_func("nnvm.compiler._register_alter_op_layout")
def register_compute(op_name, f=None, level=10): def register_compute(op_name, f=None, level=10):
"""Register compute function for operator """Register compute function for operator
...@@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10): ...@@ -93,3 +94,29 @@ def register_pattern(op_name, pattern, level=10):
The priority level The priority level
""" """
_register_pattern(op_name, pattern, level) _register_pattern(op_name, pattern, level)
def register_alter_op_layout(op_name, f=None, level=10):
"""Register alter layout function for operator
Parameters
----------
op_name : str
The name of operator
f : function
The schedule function
level : int
The priority level
Returns
-------
fregister : function
Register function if f is not specified.
"""
def register(myf):
"""internal register function"""
_register_alter_op_layout(op_name, myf, level)
return myf
return register(f) if f else register
...@@ -294,7 +294,7 @@ int NNSymbolGetNumOutputs(SymbolHandle symbol, ...@@ -294,7 +294,7 @@ int NNSymbolGetNumOutputs(SymbolHandle symbol,
nn_uint *output_count) { nn_uint *output_count) {
Symbol *s = static_cast<Symbol*>(symbol); Symbol *s = static_cast<Symbol*>(symbol);
API_BEGIN(); API_BEGIN();
*output_count = static_cast<nn_uint>(s->outputs.size()); *output_count = static_cast<nn_uint>(s->outputs.size());
API_END(); API_END();
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file alter_op_layout.cc
* \brief Alter the operator layouts. Keep inferred layouts (if any) from previous stages.
* e.g., convolution may calculates faster with NCHW16c layout.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/pass_functions.h>
#include <tvm/tvm.h>
#include <algorithm>
#include <functional>
#include "./compile_engine.h"
#include "./graph_transform.h"
namespace nnvm {
namespace compiler {
namespace {
tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
const uint32_t nid,
const ShapeVector& shape_vec,
const DTypeVector& dtype_vec) {
tvm::Array<tvm::Tensor> vec;
for (uint32_t i = 0; i < idx_graph[nid].source->num_outputs(); ++i) {
tvm::Array<tvm::Expr> shape;
for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
shape.push_back(tvm::make_const(tvm::Int(32), x));
}
vec.push_back(tvm::placeholder(
shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)])));
}
return vec;
}
Graph AlterOpLayout(const Graph& src) {
static auto& falter_op_layout =
Op::GetAttr<nnvm::compiler::FTVMAlterOpLayout >("FTVMAlterOpLayout");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = src.GetAttr<DTypeVector>("dtype");
const IndexedGraph& idx_graph = src.indexed_graph();
std::vector<std::vector<Layout> > in_layouts_of_node(idx_graph.num_nodes());
std::vector<std::vector<Layout> > out_layouts_of_node(idx_graph.num_nodes());
std::unordered_map<const Node*, uint32_t> new_nodes;
if (src.HasAttr("layout")) {
// record layouts so that LayoutTransform pass can fix layouts correctly,
// e.g., conv2d can be replaced by some contrib implement
// whose layout is different from the original one
// (which was imported from a model file).
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t nid = 0; nid < idx_graph.num_nodes(); ++nid) {
const auto &inode = idx_graph[nid];
if (falter_op_layout.count(inode.source->op())) {
// do not record input layouts of nodes that will be replaced.
continue;
}
std::vector<Layout> in_layout;
for (const auto& e : inode.inputs) {
in_layout.emplace_back(layouts[idx_graph.entry_id(e)]);
}
in_layouts_of_node[nid] = in_layout;
std::vector<Layout> out_layout;
for (uint i = 0; i < inode.source->num_outputs(); ++i) {
out_layout.emplace_back(layouts[idx_graph.entry_id(nid, i)]);
}
out_layouts_of_node[nid] = out_layout;
}
}
auto transform = [&](uint32_t nid,
const NodePtr& n,
std::vector<NodeEntry>* ret) {
nnvm::compiler::FTVMAlterOpLayout fn_alter_op_layout =
falter_op_layout.get(n->op(), nullptr);
if (fn_alter_op_layout == nullptr) {
new_nodes[n.get()] = nid;
return false;
}
// construct parameters for registered function
std::vector<Symbol> op_inputs;
tvm::Array<tvm::Tensor> tensor_infos;
CHECK_EQ(n->num_inputs(), idx_graph[nid].inputs.size());
for (uint32_t i = 0; i < n->num_inputs(); ++i) {
const nnvm::NodeEntry& input = n->inputs[i];
// input operator
Symbol op_input;
op_input.outputs.push_back(input);
op_inputs.push_back(op_input);
// input tinfo, extract from the original graph
// because it was where infer_shape & infer_type applied.
tvm::Array<tvm::Tensor> op_output_tinfos =
GetTensorInfo(idx_graph, idx_graph[nid].inputs[i].node_id,
shape_vec, dtype_vec);
tensor_infos.push_back(op_output_tinfos[input.index]);
}
// callback registered function to get a new operator.
auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
*ret = op.outputs;
return true;
};
Graph ret = nnvm::compiler::GraphTransform(src, transform);
if (src.HasAttr("layout")) {
// restore the layouts to return graph
const auto& ret_idx = ret.indexed_graph();
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
if (new_nodes.count(inode.source)) {
const std::vector<Layout>& in_layouts =
in_layouts_of_node[new_nodes[inode.source]];
for (const auto& e : inode.inputs) {
ret_layouts[ret_idx.entry_id(e)] = in_layouts[e.index];
}
const std::vector<Layout>& out_layouts =
out_layouts_of_node[new_nodes[inode.source]];
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = out_layouts[i];
}
}
}
// cannot call indexed_graph() before return the origin Graph,
// thus create a new one.
nnvm::Graph new_ret;
new_ret.outputs = ret.outputs;
new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
return new_ret;
}
return ret;
}
// register pass
NNVM_REGISTER_PASS(AlterOpLayout)
.set_body(AlterOpLayout)
.set_change_graph(true);
} // namespace
} // namespace compiler
} // namespace nnvm
...@@ -362,7 +362,7 @@ bool Pool2DBackward( ...@@ -362,7 +362,7 @@ bool Pool2DBackward(
std::vector<FoldChainInfo>* in_axis) { std::vector<FoldChainInfo>* in_axis) {
using top::Pool2DParam; using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if (out_info.axis == 1 && param.layout == top::kNCHW) { if (out_info.axis == 1 && param.layout == "NCHW") {
(*in_axis)[0] = out_info; (*in_axis)[0] = out_info;
} }
return false; return false;
...@@ -376,7 +376,7 @@ bool Pool2DForward( ...@@ -376,7 +376,7 @@ bool Pool2DForward(
FoldChainInfo* out_info) { FoldChainInfo* out_info) {
using top::Pool2DParam; using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed); const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) { if ((*in_info)[0].axis == 1 && param.layout == "NCHW") {
*out_info = (*in_info)[0]; *out_info = (*in_info)[0];
} }
return false; return false;
...@@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward( ...@@ -467,7 +467,7 @@ bool Conv2DScaleAxisBackward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if (out_info.kind != kPending) return false; if (out_info.kind != kPending) return false;
// only optimize for nchw for now // only optimize for nchw for now
if (param.layout == top::kNCHW && out_info.axis == 1) { if (param.layout == "NCHW" && out_info.axis == 1) {
(*in_axis)[1].kind = kMulConsumer; (*in_axis)[1].kind = kMulConsumer;
(*in_axis)[1].axis = 0; (*in_axis)[1].axis = 0;
(*in_axis)[1].source = out_info.source; (*in_axis)[1].source = out_info.source;
...@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward( ...@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if ((*in_info)[0].kind != kPending) return false; if ((*in_info)[0].kind != kPending) return false;
// only optimize for nchw for now // only optimize for nchw for now
if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) { if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
(*in_info)[1].kind = kMulConsumer; (*in_info)[1].kind = kMulConsumer;
(*in_info)[1].axis = 1; (*in_info)[1].axis = 1;
(*in_info)[1].source = (*in_info)[0].source; (*in_info)[1].source = (*in_info)[0].source;
......
/*!
* Copyright (c) 2017 by Contributors
* \file layout_transform.cc
* \brief Transforms layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/contrib_op_param.h>
namespace nnvm {
namespace compiler {
const TLayoutInfo& GetDefaultLayout() {
static TLayoutInfo default_layout = "default";
return default_layout;
}
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
const std::string& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src_layout"] = src;
n->attrs.dict["dst_layout"] = dst;
n->op()->attr_parser(&(n->attrs));
return n;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& op_vecop =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");
const IndexedGraph& idx = src.indexed_graph();
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
}
if (op_vecop.count(inode.source->op())) {
new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs());
}
// set up output and input layouts
std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
if (op_layout_request.count(new_node->op())) {
std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
CHECK(op_layout_request[new_node->op()](
new_node->attrs, &request_ilayouts, &produce_olayouts))
<< "Layout request fail";
CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
}
}
bool map_layout = is_map_op(nid);
if (map_layout) {
const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
for (const auto& e : inode.inputs) {
if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
break;
}
}
if (map_layout) {
for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = layout;
}
}
}
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
TLayoutInfo request = request_ilayouts[i];
if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_vec[nid] = new_node;
}
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
if (produce != GetDefaultLayout()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
} // namespace compiler
} // namespace nnvm
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/compiler/packed_func_ext.h> #include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <tvm/runtime/c_runtime_api.h>
#include "./node_attr.h" #include "./node_attr.h"
#include "compile_engine.h" #include "compile_engine.h"
...@@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys") ...@@ -62,6 +63,23 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._dict_keys")
*rv = keys; *rv = keys;
}); });
TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
.set_body([](TVMArgs args, TVMRetValue *rv) {
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fpack = [f](const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos) {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
<< ") but get code = " << ret.type_code();
return *(static_cast<Symbol*>(ret.value().v_handle));
};
op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
});
// custom version of TVM compute // custom version of TVM compute
TVM_REGISTER_GLOBAL("nnvm._register_compute") TVM_REGISTER_GLOBAL("nnvm._register_compute")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
...@@ -84,7 +102,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") ...@@ -84,7 +102,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
TVM_REGISTER_GLOBAL("nnvm._register_schedule") TVM_REGISTER_GLOBAL("nnvm._register_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc()); PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]); Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fschedule = [f](const NodeAttrs& attrs, auto fschedule = [f](const NodeAttrs& attrs,
......
...@@ -22,7 +22,8 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -22,7 +22,8 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry beta, nnvm::NodeEntry beta,
nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var, nnvm::NodeEntry moving_var,
TShape dshape) { TShape dshape,
TShape bshape) {
CHECK_NE(dshape.ndim(), 0); CHECK_NE(dshape.ndim(), 0);
CHECK(attrs.op); CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm"); static const Op* bn_op = Op::Get("batch_norm");
...@@ -60,13 +61,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -60,13 +61,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
"elemwise_add", bn_name + "_add_beta", {shift, beta}); "elemwise_add", bn_name + "_add_beta", {shift, beta});
} }
int axis = param.axis; int axis = param.axis;
scale = ExpandBiasToMatchAxis(scale, dshape.ndim(), 1, axis); scale = ExpandBiasToMatchAxis(scale, dshape.ndim()-bshape.ndim()+1, 1, axis);
shift = ExpandBiasToMatchAxis(shift, dshape.ndim(), 1, axis); shift = ExpandBiasToMatchAxis(shift, dshape.ndim()-bshape.ndim()+1, 1, axis);
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data", NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale}); {data, scale});
out = MakeNode("broadcast_add", bn_name + "_out", out = MakeNode("broadcast_add", bn_name + "_out",
{out, shift}); {out, shift});
// It is invalid to ref the other values of BN after infernece transform. // It is invalid to ref the other values of BN after inference transform.
NodeEntry undef = MakeNode("__undef__", "undef", {}); NodeEntry undef = MakeNode("__undef__", "undef", {});
return {out, undef, undef}; return {out, undef, undef};
} }
...@@ -87,7 +89,8 @@ Graph SimplifyInference(nnvm::Graph src) { ...@@ -87,7 +89,8 @@ Graph SimplifyInference(nnvm::Graph src) {
n->inputs[2], n->inputs[2],
n->inputs[3], n->inputs[3],
n->inputs[4], n->inputs[4],
shape_vec[idx.entry_id(nid, 0)]); shape_vec[idx.entry_id(nid, 0)],
shape_vec[idx.entry_id(nid, 1)]);
return true; return true;
} else if (n->op() == dropout_op) { } else if (n->op() == dropout_op) {
NodeEntry undef = MakeNode("__undef__", "undef", {}); NodeEntry undef = MakeNode("__undef__", "undef", {});
...@@ -101,7 +104,8 @@ Graph SimplifyInference(nnvm::Graph src) { ...@@ -101,7 +104,8 @@ Graph SimplifyInference(nnvm::Graph src) {
} }
NNVM_REGISTER_PASS(SimplifyInference) NNVM_REGISTER_PASS(SimplifyInference)
.set_body(SimplifyInference); .set_body(SimplifyInference)
.set_change_graph(true);
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
/*!
* Copyright (c) 2018 by Contributors
* \file correct_layout.cc
* \brief Infer and correct layout.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/layout.h>
namespace nnvm {
namespace pass {
nnvm::NodePtr CreateLayoutTransformNode(const Layout& src,
const Layout& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++);
n->attrs.dict["src_layout"] = src.name();
n->attrs.dict["dst_layout"] = dst.name();
n->op()->attr_parser(&(n->attrs));
return n;
}
using LayoutAttrDict = std::unordered_map<const Node*, std::vector<Layout> >;
/*!
* \brief A simple layout infer pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph CorrectLayout(nnvm::Graph src) {
static auto& op_infer_layout =
nnvm::Op::GetAttr<FInferLayout>("FInferLayout");
const IndexedGraph& idx = src.indexed_graph();
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// (new) NodePtr -> output_layouts
LayoutAttrDict new_layouts;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *(inode.source);
if (new_node->is_variable()) {
// Variable node. No operator. Only one output entry.
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
if (src.HasAttr("layout_inputs")) {
new_layouts[new_node.get()] =
{src.GetAttr<std::vector<Layout> >("layout_inputs")[input_id]};
} else {
new_layouts[new_node.get()] = {Layout::Undef()};
}
mirror_vec[nid] = new_node;
continue;
}
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
// set up output and input layouts
std::vector<Layout> request_ilayouts(num_inputs, Layout::Undef());
for (size_t i = 0; i < num_inputs; ++i) {
const IndexedGraph::NodeEntry& input_entry = inode.inputs[i];
const NodePtr& new_input_node = mirror_vec[input_entry.node_id];
CHECK(new_input_node != nullptr);
// fill inputs by previous node (DFS order) inferred layouts.
const auto& layouts_iter = new_layouts.find(new_input_node.get());
CHECK(layouts_iter != new_layouts.end());
request_ilayouts[i] = layouts_iter->second[input_entry.index];
}
// layouts produced by previous node.
std::vector<Layout> produce_ilayouts(request_ilayouts);
// input layouts from last pass of LayoutTransform (if apply)
std::vector<Layout> last_request_ilayouts(num_inputs, Layout::Undef());
// fill outputs by last pass of LayoutTransform (if apply)
std::vector<Layout> produce_olayouts(num_outputs, Layout::Undef());
if (src.HasAttr("layout")) {
const auto& layouts = src.GetAttr<std::vector<Layout> >("layout");
for (uint32_t i = 0; i < num_outputs; ++i) {
produce_olayouts[i] = layouts[idx.entry_id(nid, i)];
}
for (uint32_t i = 0; i < num_inputs; ++i) {
last_request_ilayouts[i] = layouts[idx.entry_id(inode.inputs[i])];
}
}
const auto& flayout = op_infer_layout[new_node->op()];
CHECK(flayout != nullptr) << "Attribute FInferLayout"
<< " is not registered by op " << inode.source->op()->name
<< " we are not able to complete layout transform.";
CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts))
<< "Layout infer fail";
CHECK_EQ(request_ilayouts.size(), num_inputs);
CHECK_EQ(produce_olayouts.size(), num_outputs);
// update new layouts
new_layouts[new_node.get()] = std::move(produce_olayouts);
for (uint32_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version};
// insert layout_transform if necessary
const Layout& produce = produce_ilayouts[i];
const Layout& request = request_ilayouts[i];
if (produce != request && produce.defined()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output{tnode, 0, 0};
new_node->inputs[i] = tnode_output;
// layout produced by LayoutTransformNode
new_layouts[tnode.get()] = {request};
} else if (!produce.defined()) {
// do reverse infer
new_layouts[in.get()][e.index] = request;
}
}
mirror_vec[nid] = new_node;
}
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : idx.outputs()) {
outputs.emplace_back(nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
nnvm::Graph ret;
ret.outputs = outputs;
// restore the layouts to return graph
const auto& ret_idx = ret.indexed_graph();
std::vector<Layout> ret_layouts(ret_idx.num_node_entries(), Layout::Undef());
for (uint32_t nid = 0; nid < ret_idx.num_nodes(); ++nid) {
const auto& inode = ret_idx[nid];
const auto& layout_iter = new_layouts.find(inode.source);
if (layout_iter != new_layouts.end()) {
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
ret_layouts[ret_idx.entry_id(nid, i)] = std::move(layout_iter->second[i]);
}
}
}
// cannot call indexed_graph() before return the origin Graph,
// thus create a new one
nnvm::Graph new_ret;
new_ret.outputs = std::move(outputs);
new_ret.attrs["layout"] = std::make_shared<any>(std::move(ret_layouts));
return new_ret;
}
// register pass
NNVM_REGISTER_PASS(CorrectLayout)
.describe("Return a layout-transformed graph of src.")
.set_body(CorrectLayout)
.provide_graph_attr("layout")
.set_change_graph(true);
DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout);
} // namespace pass
} // namespace nnvm
...@@ -158,7 +158,7 @@ Graph InferAttr(Graph &&ret, ...@@ -158,7 +158,7 @@ Graph InferAttr(Graph &&ret,
} else { } else {
CHECK(!last_iter) CHECK(!last_iter)
<< "Attribute " << infer_name << "Attribute " << infer_name
<< " is not registed by op " << inode.source->op()->name << " is not registered by op " << inode.source->op()->name
<< " we are not able to complete the inference because of this"; << " we are not able to complete the inference because of this";
} }
} }
......
...@@ -6,9 +6,12 @@ ...@@ -6,9 +6,12 @@
#ifndef NNVM_TOP_ELEMWISE_OP_COMMON_H_ #ifndef NNVM_TOP_ELEMWISE_OP_COMMON_H_
#define NNVM_TOP_ELEMWISE_OP_COMMON_H_ #define NNVM_TOP_ELEMWISE_OP_COMMON_H_
#include <nnvm/layout.h>
#include <nnvm/top/nn.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <functional>
#include "./op_common.h" #include "./op_common.h"
namespace nnvm { namespace nnvm {
...@@ -100,12 +103,176 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -100,12 +103,176 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1); attrs, in_attrs, out_attrs, -1);
} }
template<int n_in, int n_out>
inline bool ElemwiseFixedLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts,
const std::function<Layout(const Layout& in)>& finfer) {
const size_t in_size = (n_in == -1) ? in_layouts->size() : static_cast<size_t>(n_in);
const size_t out_size = (n_out == -1) ? out_layouts->size() : static_cast<size_t>(n_out);
auto deduce = [&](Layout *target, const std::vector<Layout> *vec,
size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
if (vec->at(i).defined()) {
if (!target->defined()) {
*target = vec->at(i);
}
CHECK_EQ(*target, vec->at(i))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << *target
<< ", got " << vec->at(i);
}
}
};
Layout in, last_in, out;
deduce(&in, in_layouts, in_size, "input");
deduce(&last_in, last_in_layouts, in_size, "input (last infer pass)");
deduce(&out, out_layouts, out_size, "output");
if (!last_in.defined()) {
last_in = in;
} else {
// else we copy in_layout produced by last infer pass to in_layout,
// and let LayoutTransform pass
// to insert an layout_transform node to fix the input layout.
in = last_in;
}
out = finfer(in);
auto write = [](std::vector<Layout> *vec, Layout& value, size_t size) {
for (size_t i = 0; i < size; ++i) {
vec->at(i) = value;
}
};
if (in.defined()) write(in_layouts, in, in_size);
if (out.defined()) write(out_layouts, out, out_size);
return true;
}
/*! \brief Fix the input layout as the previous inferred (if any) and copy to output */
template<int n_in, int n_out>
inline bool ElemwiseFixedLayoutCopyToOut(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
return ElemwiseFixedLayout<n_in, n_out>(
attrs, in_layouts, last_in_layouts, out_layouts, [](const Layout& in) {
return in;
});
}
/*! \brief Fix the input layout as the previous inferred (if any) and do not define output */
template<int n_in, int n_out>
inline bool ElemwiseFixedLayoutUnknownOut(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
return ElemwiseFixedLayout<n_in, n_out>(
attrs, in_layouts, last_in_layouts, out_layouts, [](const Layout& in) {
return Layout::Undef();
});
}
/*! \brief take arbitrary input layout and copy to output */
template<int n_in, int n_out>
inline bool ElemwiseArbitraryLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const size_t in_size = (n_in == -1) ? in_layouts->size() : static_cast<size_t>(n_in);
const size_t out_size = (n_out == -1) ? out_layouts->size() : static_cast<size_t>(n_out);
Layout in;
for (size_t i = 0; i < in_size; ++i) {
if (!in.defined()) in = in_layouts->at(i);
CHECK_EQ(in, in_layouts->at(i))
<< "Incompatible attr in node " << attrs.name << " at " << i
<< "-th input: expected " << in
<< ", got " << in_layouts->at(i);
}
if (in.defined()) {
for (size_t i = 0; i < out_size; ++i) {
out_layouts->at(i) = in;
}
}
return true;
}
/*!
* \brief try to convert right layout to left layout if they are different.
* if the converting fails, it will use the last inferred layouts.
*/
inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
CHECK_EQ(in_layouts->size(), 2U);
CHECK_EQ(last_in_layouts->size(), 2U);
CHECK_EQ(out_layouts->size(), 1U);
const Layout& lhs_last = (*last_in_layouts)[0];
const Layout& rhs_last = (*last_in_layouts)[1];
CHECK((lhs_last.defined() && rhs_last.defined()) ||
(!lhs_last.defined() && !rhs_last.defined()));
const Layout& lhs = (*in_layouts)[0];
const Layout& rhs = (*in_layouts)[1];
if (!lhs.defined() && !rhs.defined()) {
CHECK(!lhs_last.defined() && !rhs_last.defined())
<< "Lost input layouts in node " << attrs.name
<< ": last inferred lhs=" << lhs_last << ", rhs=" << rhs_last;
return true;
} else if (!lhs.defined()) {
CHECK(!lhs_last.defined() && !rhs_last.defined());
in_layouts->at(0) = rhs;
out_layouts->at(0) = rhs;
return true;
} else if (!rhs.defined()) {
CHECK(!lhs_last.defined() && !rhs_last.defined());
in_layouts->at(1) = lhs;
out_layouts->at(0) = lhs;
return true;
}
if (lhs == rhs) {
// for same layout, we can always do binary calculation
// and pass the layout to next layer
out_layouts->at(0) = lhs;
return true;
}
if (rhs.convertible(lhs)) {
in_layouts->at(1) = lhs;
out_layouts->at(0) = lhs;
} else {
CHECK(lhs_last.defined() && rhs_last.defined())
<< "Incompatible input layouts in node " << attrs.name
<< ". lhs: " << lhs << ", rhs: " << rhs;
CHECK(lhs_last == rhs_last);
in_layouts->at(0) = lhs_last;
in_layouts->at(1) = rhs_last;
out_layouts->at(0) = lhs_last;
}
return true;
}
#define NNVM_REGISTER_ELEMWISE_UNARY_OP(name) \ #define NNVM_REGISTER_ELEMWISE_UNARY_OP(name) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseArbitraryLayout<1, 1>) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \ [](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}}; \
...@@ -131,6 +298,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -131,6 +298,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \ .set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseBinaryKeepLeftLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
...@@ -150,6 +319,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -150,6 +319,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
ParamGetAttrDict<ElementWiseReduceParam>) \ ParamGetAttrDict<ElementWiseReduceParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", \ .set_attr<nnvm::FInferShape>("FInferShape", \
ElementWiseReduceShape) \ ElementWiseReduceShape) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutCopyToOut<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \ .set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
.add_argument("args", "Symbol[]", "Positional input arguments") .add_argument("args", "Symbol[]", "Positional input arguments")
...@@ -166,6 +337,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs, ...@@ -166,6 +337,8 @@ inline bool ElementWiseReduceType(const NodeAttrs& attrs,
static_cast<int>(kFloat32)); \ static_cast<int>(kFloat32)); \
return true; \ return true; \
}) \ }) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_attr<FGradient>( \ .set_attr<FGradient>( \
"FGradient", [](const NodePtr& n, \ "FGradient", [](const NodePtr& n, \
const std::vector<NodeEntry>& ograds) { \ const std::vector<NodeEntry>& ograds) { \
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <nnvm/layout.h>
#include <nnvm/top/nn.h> #include <nnvm/top/nn.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -40,100 +41,47 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) { ...@@ -40,100 +41,47 @@ inline std::vector<std::string> UseBiasListInputNames(const NodeAttrs& attrs) {
* \param dst_layout target layout * \param dst_layout target layout
* \return shape in target layout * \return shape in target layout
*/ */
inline TShape ConvertLayout(TShape src, int src_layout, int dst_layout, bool is_weight = false) { inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout& dst_layout) {
if (src_layout == dst_layout) return src; if (src_layout == dst_layout) {
TShape dst = src; return src;
if (src.ndim() == 3) { } else if (!src_layout.defined()) {
switch (src_layout) { LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
case kNCW: break; } else if (!dst_layout.defined()) {
case kNWC: { LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
std::swap(dst[1], dst[2]); }
break;
} CHECK(src_layout.convertible(dst_layout)) << "cannot convert from "
default: { << src_layout << " to " << dst_layout;
LOG(FATAL) << "inavlid layout for 3d shape" << src_layout;
} TShape dst(dst_layout.ndim());
} for (size_t i = 0; i < src_layout.ndim(); ++i) {
switch (dst_layout) { Layout::LayoutDim src_dim = src_layout[i];
case kNCW: break; if (Layout::is_superdim(src_dim)) {
case kNWC: { int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim));
std::swap(dst[1], dst[2]); int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim));
break; int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim));
} int src_factor = src_layout.subsizeof(src_dim);
default: { int dst_factor = dst_layout.subsizeof(src_dim);
LOG(FATAL) << "inavlid layout for 3d shape" << dst_layout;
} uint32_t src_dim_size = src[i];
} if (src_minor_pos >= 0) {
} else if (src.ndim() == 4) { CHECK_EQ(src_factor, src[src_minor_pos]) << "src shape " << src
switch (src_layout) { << " does not agree with layout " << src_layout;
case kNCHW: break; src_dim_size *= src_factor;
case kNHWC: {
if (is_weight) {
dst[2] = src[0];
dst[3] = src[1];
dst[1] = src[2];
dst[0] = src[3];
} else {
dst[2] = src[1];
dst[3] = src[2];
dst[1] = src[3];
}
break;
}
default: {
LOG(FATAL) << "inavlid layout for 4d shape" << src_layout;
}
}
src = dst;
switch (dst_layout) {
case kNCHW: break;
case kNHWC: {
if (is_weight) {
dst[0] = src[2];
dst[1] = src[3];
dst[2] = src[1];
dst[3] = src[0];
} else {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[1];
}
break;
}
default: {
LOG(FATAL) << "inavlid layout for 4d shape" << dst_layout;
}
}
} else if (src.ndim() == 5) {
switch (src_layout) {
case kNCDHW: break;
case kNDHWC: {
dst[2] = src[1];
dst[3] = src[2];
dst[4] = src[3];
dst[1] = src[4];
break;
}
default: {
LOG(FATAL) << "inavlid layout for 5d shape" << src_layout;
}
}
src = dst;
switch (dst_layout) {
case kNCDHW: break;
case kNDHWC: {
dst[1] = src[2];
dst[2] = src[3];
dst[3] = src[4];
dst[4] = src[1];
break;
} }
default: {
LOG(FATAL) << "inavlid layout for 5d shape" << dst_layout; dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
<< " from " << src_layout
<< " to " << dst_factor
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
} }
} }
} else {
LOG(FATAL) << "no layout option for " << dst.ndim() << " dimensions";
} }
return dst; return dst;
} }
......
...@@ -19,6 +19,7 @@ DMLC_REGISTER_PARAMETER(UpSamplingParam); ...@@ -19,6 +19,7 @@ DMLC_REGISTER_PARAMETER(UpSamplingParam);
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape, std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) { std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed); const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(in_shape->size(), 1U);
CHECK_EQ(out_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U);
...@@ -33,6 +34,19 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs, ...@@ -33,6 +34,19 @@ inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool UpsamplingLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 1U);
CHECK_EQ(out_layouts->size(), 1U);
const Layout layout(param.layout);
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, layout);
return true;
}
NNVM_REGISTER_OP(upsampling) NNVM_REGISTER_OP(upsampling)
.describe(R"(Perform nearest neighbor upsampling to input array. .describe(R"(Perform nearest neighbor upsampling to input array.
...@@ -46,6 +60,7 @@ NNVM_REGISTER_OP(upsampling) ...@@ -46,6 +60,7 @@ NNVM_REGISTER_OP(upsampling)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape) .set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", UpsamplingLayout)
.set_num_outputs(1) .set_num_outputs(1)
.set_num_inputs(1) .set_num_inputs(1)
.set_support_level(2); .set_support_level(2);
......
...@@ -203,6 +203,13 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs, ...@@ -203,6 +203,13 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
} \ } \
} }
#define NNVM_ASSIGN_LAYOUT(outputs, index, layout) \
{ \
if (layout.defined()) { \
(outputs)[index] = layout; \
} \
}
/*! /*!
* \brief macro assign rhs shape to lhs * \brief macro assign rhs shape to lhs
* Use macro so we can see the error file more clearly * Use macro so we can see the error file more clearly
...@@ -253,6 +260,14 @@ inline bool ZeroShape(const NodeAttrs& attrs, ...@@ -253,6 +260,14 @@ inline bool ZeroShape(const NodeAttrs& attrs,
} }
} }
// do not infer layout
inline bool ZeroLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
return true;
}
// simply assign output shape or type from input // simply assign output shape or type from input
template<typename AttrType, int in_index, int out_index> template<typename AttrType, int in_index, int out_index>
inline bool AssignOutputAttr(const NodeAttrs& attrs, inline bool AssignOutputAttr(const NodeAttrs& attrs,
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <nnvm/compiler/op_attr_types.h> #include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h> #include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h> #include <nnvm/top/tensor.h>
#include <nnvm/top/nn.h>
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/broadcast.h" #include "topi/broadcast.h"
...@@ -74,6 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example. ...@@ -74,6 +75,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<BroadcastToParam>)
.set_attr<FInferShape>("FInferShape", BroadcastToInferShape) .set_attr<FInferShape>("FInferShape", BroadcastToInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -115,7 +117,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -115,7 +117,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
} else { } else {
CHECK(l == 1 || r == 1) CHECK(l == 1 || r == 1)
<< "operands could not be broadcast together with shapes " << "operands could not be broadcast together with shapes "
<< lhs << " " << rhs; << lhs << " " << rhs << ", l=" << l << ", r=" << r;
out[i] = std::max(l, r); out[i] = std::max(l, r);
} }
} else { } else {
...@@ -126,6 +128,77 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -126,6 +128,77 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool BinaryBroadcastInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
Layout lhs = (*ilayouts)[0];
Layout rhs = (*ilayouts)[1];
Layout out(Layout::Undef());
if (lhs.defined() && rhs.defined()) {
if (lhs == rhs) {
NNVM_ASSIGN_LAYOUT(*olayouts, 0, lhs);
return true;
}
// For example, NCHW <-> CHW, N16nCH16cW <-> HCW16c, etc, are broadcast-convertible
// because as the definition, CHW can broadcast with NCHW.
// For the second case, we can convert HCW16c to CH16cW then it can broadcast with N16nCH16cW.
// But CNHW <-> CHW, NCHW16n <-> CHW are not,
// because not matter how we adjust the layout of 'CHW',
// we can never have an 'N' between 'C' and "HW".
size_t l_start = 0, r_start = 0;
size_t l = 0, r = 0;
bool find_first_match = false;
while (l < lhs.ndim() && r < rhs.ndim()) {
if (!rhs.contains(Layout::to_superdim(lhs[l]))) {
CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible";
l_start = ++l;
} else if (!lhs.contains(Layout::to_superdim(rhs[r]))) {
CHECK(!find_first_match) << lhs << " and " << rhs << " are not broadcast-convertible";
r_start = ++r;
} else {
find_first_match = true;
++l; ++r;
}
}
if (l_start > 0 && r_start > 0) {
LOG(FATAL) << lhs << " and " << rhs << " are not broadcast-convertible";
} else if (l_start > 0) {
rhs = lhs.sublayout(l_start, lhs.ndim()-l_start);
out = lhs;
} else if (r_start > 0) {
lhs = rhs.sublayout(r_start, rhs.ndim()-r_start);
out = rhs;
} else {
// prior to keep left layout
rhs = lhs;
out = lhs;
}
} else if (lhs.defined()) {
const Layout& last_lhs = last_ilayouts->at(0);
if (last_lhs.defined()) {
CHECK(lhs.convertible(last_lhs)) << "current lhs layout " << lhs
<< " cannot be converted to the original one " << last_lhs;
lhs = last_lhs;
// cannot decide output layout
}
} else if (rhs.defined()) {
const Layout& last_rhs = last_ilayouts->at(1);
if (last_rhs.defined()) {
CHECK(rhs.convertible(last_rhs)) << "current rhs layout " << rhs
<< " cannot be converted to the original one " << last_rhs;
rhs = last_rhs;
// cannot decide output layout
}
}
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs);
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, out);
return true;
}
#define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \ #define NNVM_REGISTER_BINARY_BROADCAST_OP(name) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
...@@ -133,6 +206,8 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs, ...@@ -133,6 +206,8 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
.set_num_outputs(1) \ .set_num_outputs(1) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \ .set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
BinaryBroadcastInferLayout) \
.set_attr<FInplaceOption>("FInplaceOption", \ .set_attr<FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs) { \ [](const NodeAttrs& attrs) { \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
......
...@@ -333,6 +333,7 @@ NNVM_REGISTER_INIT_OP(full) ...@@ -333,6 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__()) .add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros) NNVM_REGISTER_INIT_OP(zeros)
...@@ -345,6 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros) ...@@ -345,6 +346,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
NNVM_REGISTER_INIT_OP(ones) NNVM_REGISTER_INIT_OP(ones)
...@@ -357,6 +359,7 @@ NNVM_REGISTER_INIT_OP(ones) ...@@ -357,6 +359,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__()) .add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>) .set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>) .set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_attr<FInferLayout>("FInferLayout", ZeroLayout)
.set_support_level(4); .set_support_level(4);
// full_like // full_like
...@@ -693,6 +696,7 @@ Example:: ...@@ -693,6 +696,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -41,6 +41,31 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, ...@@ -41,6 +41,31 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool DotInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const MatMulParam& param = nnvm::get<MatMulParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
const Layout& lhs = last_ilayouts->at(0).defined() ? last_ilayouts->at(0)
: ilayouts->at(0);
const Layout& rhs = last_ilayouts->at(1).defined() ? last_ilayouts->at(1)
: ilayouts->at(1);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, lhs);
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, rhs);
if (lhs.ndim() > 1 && rhs.ndim() > 1) {
// concat lhs and rhs layout
const Layout& lhs_out = param.transpose_a ? lhs.reverse() : lhs;
const Layout& rhs_out = param.transpose_b ? rhs.reverse() : rhs;
Layout out = std::move(lhs_out.sublayout(0, lhs_out.ndim()-1) +
rhs_out.sublayout(1, rhs_out.ndim()-1));
NNVM_ASSIGN_LAYOUT(*olayouts, 0, out);
}
return true;
}
NNVM_REGISTER_OP(matmul) NNVM_REGISTER_OP(matmul)
.describe(R"doc(Matrix multiplication of two arrays. .describe(R"doc(Matrix multiplication of two arrays.
...@@ -67,6 +92,7 @@ NNVM_REGISTER_OP(matmul) ...@@ -67,6 +92,7 @@ NNVM_REGISTER_OP(matmul)
.add_argument("rhs", "NDArray-or-Symbol", "The second input") .add_argument("rhs", "NDArray-or-Symbol", "The second input")
.set_attr<FInferShape>("FInferShape", DotShape) .set_attr<FInferShape>("FInferShape", DotShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferLayout>("FInferLayout", DotInferLayout)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
......
...@@ -111,6 +111,8 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) { ...@@ -111,6 +111,8 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \ .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \ .set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \ .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferLayout>("FInferLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.set_num_outputs(1) .set_num_outputs(1)
......
...@@ -45,6 +45,15 @@ This is an experimental operator. ...@@ -45,6 +45,15 @@ This is an experimental operator.
return Array<Tensor>{ topi::identity(inputs[1]) }; return Array<Tensor>{ topi::identity(inputs[1]) };
}) })
.set_attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferLayout>(
"FInferLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
NNVM_ASSIGN_LAYOUT(*in_layouts, 1, (*in_layouts)[0]);
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, (*in_layouts)[0]);
return true;
})
.set_attr<FInplaceOption>( .set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) { "FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{1, 0}}; return std::vector<std::pair<int, int> >{{1, 0}};
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <nnvm/compiler/util.h> #include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h> #include <nnvm/top/tensor.h>
#include <cctype> #include <cctype>
#include <sstream>
#include "../op_common.h" #include "../op_common.h"
#include "../elemwise_op_common.h" #include "../elemwise_op_common.h"
#include "topi/nn/flatten.h" #include "topi/nn/flatten.h"
...@@ -63,6 +64,7 @@ Example:: ...@@ -63,6 +64,7 @@ Example::
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", FlattenInferShape) .set_attr<FInferShape>("FInferShape", FlattenInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
...@@ -119,6 +121,22 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, ...@@ -119,6 +121,22 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
return dshape.Size() != 0; return dshape.Size() != 0;
} }
inline bool ConcatenateInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
CHECK_EQ(ilayouts->size(), last_ilayouts->size());
CHECK_EQ(olayouts->size(), 1U);
for (size_t i = 0; i < ilayouts->size(); ++i) {
const Layout& input = last_ilayouts->at(i).defined() ?
last_ilayouts->at(i) : ilayouts->at(i);
NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
}
return true;
}
NNVM_REGISTER_OP(concatenate) NNVM_REGISTER_OP(concatenate)
.describe(R"code(Joins input arrays along a given axis. .describe(R"code(Joins input arrays along a given axis.
...@@ -156,6 +174,7 @@ Example:: ...@@ -156,6 +174,7 @@ Example::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ConcatenateParam>)
.set_attr<FInferShape>("FInferShape", ConcatenateInferShape) .set_attr<FInferShape>("FInferShape", ConcatenateInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferLayout>("FInferLayout", ConcatenateInferLayout)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs, "FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
...@@ -177,7 +196,8 @@ inline bool ExpandDimsInferShape(const NodeAttrs& attrs, ...@@ -177,7 +196,8 @@ inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(in_shape->size(), 1U);
const TShape& dshape = in_shape->at(0); const TShape& dshape = in_shape->at(0);
int ndim = static_cast<int>(dshape.ndim()); int ndim = static_cast<int>(dshape.ndim());
CHECK(param.axis >= -ndim - 1 && param.axis <= ndim); CHECK(param.axis >= -ndim - 1 && param.axis <= ndim)
<< "with axis = " << param.axis << " ndim = " << ndim;
int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis; int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
std::vector<dim_t> oshape; std::vector<dim_t> oshape;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
...@@ -198,7 +218,7 @@ NNVM_REGISTER_OP(expand_dims) ...@@ -198,7 +218,7 @@ NNVM_REGISTER_OP(expand_dims)
.describe(R"code(Inserts a new axis of size 1 into the array shape .describe(R"code(Inserts a new axis of size 1 into the array shape
For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1, num_newaxis=5)`` For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1, num_newaxis=5)``
will return a new array with shape ``(2,5,3,4)``. will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
)code" NNVM_ADD_FILELINE) )code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input tensor") .add_argument("data", "Tensor", "Input tensor")
...@@ -207,6 +227,7 @@ will return a new array with shape ``(2,5,3,4)``. ...@@ -207,6 +227,7 @@ will return a new array with shape ``(2,5,3,4)``.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape) .set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -249,6 +270,8 @@ Examples:: ...@@ -249,6 +270,8 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>)
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>) .set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FGradient>( .set_attr<FGradient>(
...@@ -345,6 +368,7 @@ along which to split the array. ...@@ -345,6 +368,7 @@ along which to split the array.
.set_attr_parser(SplitParamParser) .set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape) .set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, -1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(SplitNumOutputs) .set_num_outputs(SplitNumOutputs)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -387,6 +411,7 @@ NNVM_REGISTER_OP(cast) ...@@ -387,6 +411,7 @@ NNVM_REGISTER_OP(cast)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<CastParam>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) .set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", CastInferType) .set_attr<FInferType>("FInferType", CastInferType)
.set_attr<FInferLayout>("FInferLayout", ElemwiseArbitraryLayout<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(1); .set_support_level(1);
...@@ -539,6 +564,7 @@ The significance of each is explained below: ...@@ -539,6 +564,7 @@ The significance of each is explained below:
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReshapeParam>)
.set_attr<FInferShape>("FInferShape", ReshapeInferShape) .set_attr<FInferShape>("FInferShape", ReshapeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -578,6 +604,8 @@ the input array into an output array with the same shape as the second input arr ...@@ -578,6 +604,8 @@ the input array into an output array with the same shape as the second input arr
return true; return true;
}) })
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) .set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
// never transform layout of the second input array.
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -660,6 +688,7 @@ Examples:: ...@@ -660,6 +688,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SqueezeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape) .set_attr<nnvm::FInferShape>("FInferShape", SqueezeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", ElemwiseFixedLayoutUnknownOut<1, 1>)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
...@@ -680,7 +709,7 @@ Examples:: ...@@ -680,7 +709,7 @@ Examples::
}) })
.set_support_level(1); .set_support_level(1);
// tranpose // transpose
DMLC_REGISTER_PARAMETER(TransposeParam); DMLC_REGISTER_PARAMETER(TransposeParam);
inline bool TransposeShape(const nnvm::NodeAttrs& attrs, inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
...@@ -708,6 +737,39 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, ...@@ -708,6 +737,39 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
return true; return true;
} }
inline bool TransposeInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(ilayouts->size(), 1U);
CHECK_EQ(olayouts->size(), 1U);
const Layout& input = last_ilayouts->at(0).defined()
? last_ilayouts->at(0)
: ilayouts->at(0);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
if (input.defined()) {
std::ostringstream new_layout;
if (param.axes.ndim() == 0) {
for (size_t i = 0; i < input.ndim(); ++i) {
new_layout << input.at(input.ndim() - 1 - i);
}
} else {
CHECK_EQ(input.ndim(), param.axes.ndim());
for (size_t i = 0; i < input.ndim(); ++i) {
CHECK(param.axes[i] < input.ndim());
new_layout << input.at(param.axes[i]);
}
}
NNVM_ASSIGN_LAYOUT(*olayouts, 0, Layout(new_layout.str()));
}
return true;
}
NNVM_REGISTER_OP(transpose) NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array. .describe(R"code(Permutes the dimensions of an array.
...@@ -743,6 +805,7 @@ Examples:: ...@@ -743,6 +805,7 @@ Examples::
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>) .set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<TransposeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", TransposeShape) .set_attr<nnvm::FInferShape>("FInferShape", TransposeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferLayout>("FInferLayout", TransposeInferLayout)
.set_num_inputs(1) .set_num_inputs(1)
.set_num_outputs(1) .set_num_outputs(1)
.set_support_level(4) .set_support_level(4)
......
"""Unittest cases for AlterOpLayout pass"""
from nnvm import symbol as sym
from nnvm.compiler import graph_attr
from nnvm.top import registry as reg
import nnvm.graph as graph
def get_layouts(g):
ldict = {}
vlayout = g.json_attr("layout")
entry_ptr = g.index.entry_ptr
for i, n in enumerate(g.index.nodes):
begin, end = entry_ptr[i], entry_ptr[i + 1]
ldict[n["name"]] = vlayout[begin:end]
return ldict
def test_alter_conv2d_layout():
data = sym.Variable("data", shape=(1, 32, 512, 512))
conv = sym.conv2d(data, name="conv", channels=16,
kernel_size=(3,3), padding=(1,1),
use_bias=False, layout="NCHW")
relu = sym.relu(conv, name="relu")
flatten = sym.flatten(relu, name="flatten")
softmax = sym.softmax(flatten, name="softmax")
g = graph.create(softmax)
g = g.apply("CorrectLayout")
g = graph_attr.set_dtype_inputs(g, "float32")
g = g.apply(["InferShape", "InferType"])
layouts_origin = get_layouts(g)
@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
new_attrs = {k : attrs[k] for k in attrs.keys()}
new_attrs["layout"] = "NCHW16c"
new_attrs["kernel_layout"] = "NCHW16c"
new_attrs["name"] = "conv_alter"
return sym.conv2d(inputs[0], inputs[1], **new_attrs)
g = g.apply("AlterOpLayout")
layouts = get_layouts(g)
# check copy layouts
for node in ["data", "relu", "flatten", "softmax", "conv_weight"]:
assert(layouts[node] == layouts_origin[node])
assert(layouts["conv_alter"] == layouts_origin["conv"])
if __name__ == "__main__":
test_alter_conv2d_layout()
...@@ -5,9 +5,10 @@ import nnvm.symbol as sym ...@@ -5,9 +5,10 @@ import nnvm.symbol as sym
import nnvm.compiler import nnvm.compiler
from nnvm.testing.config import ctx_list from nnvm.testing.config import ctx_list
def get_sym(layout, channels): def get_sym(layout, kernel_layout, channels):
data = sym.Variable(name="data") data = sym.Variable(name="data")
data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1), layout=layout, use_bias=True) data = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1),
layout=layout, kernel_layout=kernel_layout, use_bias=True)
data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout) data = sym.max_pool2d(data=data, pool_size=(2, 2), strides=(2, 2), layout=layout)
data = sym.upsampling(data=data, scale=2, layout=layout) data = sym.upsampling(data=data, scale=2, layout=layout)
softmax_axis = 1 softmax_axis = 1
...@@ -31,8 +32,8 @@ def build_and_run(sym, params, data, out_shape): ...@@ -31,8 +32,8 @@ def build_and_run(sym, params, data, out_shape):
def test_nhwc(): def test_nhwc():
data_shape = (1, 3, 224, 224) data_shape = (1, 3, 224, 224)
out_channel = 8 out_channel = 8
nchw_sym = get_sym("NCHW", out_channel) nchw_sym = get_sym("NCHW", "OIHW", out_channel)
nhwc_sym = get_sym("NHWC", out_channel) nhwc_sym = get_sym("NHWC", "HWIO", out_channel)
conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32) conv_weight = np.random.uniform(-1, 1, (out_channel, 3, 3, 3)).astype(np.float32)
conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32) conv_bias = np.random.uniform(-1, 1, (out_channel)).astype(np.float32)
nchw_params = { nchw_params = {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment