Commit 53ac89ed by Wuwei Lin, committed by Tianqi Chen

[RELAY][PASS] CombineParallelConv2D (#2089)

parent ff5dffa4
......@@ -13,6 +13,7 @@ from .backend import graph_runtime_codegen as _graph_gen
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
......@@ -144,6 +145,10 @@ def optimize(func, params=None):
        func = ir_pass.infer_type(func)
        func = ir_pass.simplify_inference(func)

    if cfg.pass_enabled("CombineParallelConv2D"):
        func = ir_pass.infer_type(func)
        func = ir_pass.combine_parallel_conv2d(func)

    if cfg.pass_enabled("FoldScaleAxis"):
        func = ir_pass.infer_type(func)
        func = ir_pass.backward_fold_scale_axis(func)
......
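Note: CombineParallelConv2D is registered at optimization level 1 in OPT_PASS_LEVEL, so it is applied whenever the build is configured with opt_level >= 1. A minimal usage sketch, assuming the relay.build_config / relay.build API of this codebase, where func and params stand for an existing Relay function and its parameter dict:

from tvm import relay

# opt_level=1 or higher enables CombineParallelConv2D during optimization
with relay.build_config(opt_level=1):
    graph, lib, params = relay.build(func, target="llvm", params=params)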
......@@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1):
        Transformed expression, containing fused result.
    """
    return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
    """Combine multiple parallel conv2d operators that share the same input into one.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression.
    """
    return _ir_pass.CombineParallelConv2D(expr)
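A standalone sketch of invoking the new pass directly, following the pattern used in the tests added below; the pass operates on a type-checked function, so infer_type is run first:

from tvm import relay

x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(4, 4, 1, 1))
w2 = relay.var("w2", shape=(8, 4, 1, 1))
f = relay.Function([x, w1, w2],
                   relay.Tuple((relay.nn.conv2d(x, w1), relay.nn.conv2d(x, w2))))
f = relay.ir_pass.infer_type(f)
# the two parallel conv2d calls are merged into one conv2d whose outputs are
# split back with strided_slice along the channel axis
f = relay.ir_pass.combine_parallel_conv2d(f)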
/*!
 * Copyright (c) 2018 by Contributors
 * \file expr_subst.cc
 * \brief Utility functions for substituting expressions.
 */
#include <tvm/relay/expr_functor.h>
#include "./expr_subst.h"

namespace tvm {
namespace relay {

class ExprSubstituter : public ExprMutator {
 public:
  explicit ExprSubstituter(std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map)
      : subst_map_(std::move(subst_map)) {}

  Expr VisitExpr(const Expr& expr) final {
    // If the expression has a replacement in the substitution map, return it;
    // otherwise recurse with the default mutator.
    auto it = subst_map_.find(expr);
    if (it != subst_map_.end()) {
      return it->second;
    }
    return ExprMutator::VisitExpr(expr);
  }

 private:
  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
};

Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map) {
  return ExprSubstituter(std::move(subst_map)).Mutate(expr);
}

}  // namespace relay
}  // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file expr_subst.h
* \brief Utility functions for substituting expressions.
*/
#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_
#define TVM_RELAY_PASS_EXPR_SUBST_H_
#include <tvm/relay/expr.h>
#include <unordered_map>
namespace tvm {
namespace relay {
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_EXPR_SUBST_H_
......@@ -11,6 +11,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
#include <string>
#include "../op/layout.h"
......@@ -120,6 +121,19 @@ inline bool IsDepthwiseConv2D(const Call& call,
is_const_int(wshape[1], 1);
}
/*!
* \brief Get super-dimension of output channels of conv2d
* \param call The conv2d call.
* \return Super-dimension size of output channels of conv2d.
*/
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
  auto param = call->attrs.as<Conv2DAttrs>();
  auto tweight = call->args[1]->type_as<TensorTypeNode>();
  auto index = param->weight_layout.find('O');
  CHECK_NE(index, std::string::npos);
  auto channels = as_const_int(tweight->shape[index]);
  return *channels;
}
/*!
* \brief Create a Constant with a scalar
......@@ -172,6 +186,10 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
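GetConv2DSuperChannelsDim above reads the output-channel ("super") dimension of a conv2d by locating 'O' in the weight layout string and indexing the weight shape with it. A hypothetical Python mirror of that lookup, for illustration only:

def conv2d_out_channels(weight_shape, weight_layout="OIHW"):
    # the output-channel dimension is the weight axis named 'O' in the layout
    index = weight_layout.find("O")
    assert index != -1
    return weight_shape[index]

assert conv2d_out_channels((32, 16, 3, 3), "OIHW") == 32
assert conv2d_out_channels((3, 3, 16, 32), "HWIO") == 32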
from tvm import relay
import numpy as np
def test_combine_parallel_conv2d():
    """Simple testcase."""
    def before(x, w1, w2, w3, w4):
        args = [x, w1, w2, w3, w4]
        y1 = relay.nn.conv2d(x, w1)
        y2 = relay.nn.conv2d(x, w2)
        # y3 cannot be combined (3x3 kernel, unlike the 1x1 kernels of y1, y2 and y4)
        y3 = relay.nn.conv2d(x, w3)
        y4 = relay.nn.conv2d(x, w4)
        y = relay.Tuple((y1, y2, y3, y4))
        return relay.Function(args, y)

    def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
        # use a fixed order of args so alpha equal check can pass
        args = [x, w1, w2, w3, w4]
        w = relay.concatenate((w1, w2, w4), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y3 = relay.nn.conv2d(x, w3)
        y4 = relay.strided_slice(y, [0, channels1 + channels2],
                                 [None, channels1 + channels2 + channels4])
        y = relay.Tuple((y1, y2, y3, y4))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2, channels3, channels4):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
        w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))
        y_before = before(x, w1, w2, w3, w4)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 4, 4, 4)
    check((1, 4, 16, 16), 4, 8, 4, 7)
def test_combine_parallel_conv2d_scale_relu():
    """Testcase of combining conv2d + scale + relu"""
    def before(x, w1, w2, scale1, scale2, bias):
        args = [x, w1, w2, scale1, scale2, bias]
        y1 = relay.nn.conv2d(x, w1)
        y1 = relay.multiply(y1, scale1)
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(x, w2)
        y2 = relay.multiply(y2, scale2)
        y2 = relay.nn.relu(y2)
        y2 = relay.add(y2, bias)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2):
        args = [x, w1, w2, scale1, scale2, bias]
        w = relay.concatenate((w1, w2), axis=0)
        scale = relay.concatenate((scale1, scale2), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
        y = relay.multiply(y, scale)
        y = relay.nn.relu(y)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y2 = relay.add(y2, bias)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        scale1 = relay.var("scale1", shape=(channels1, 1, 1))
        scale2 = relay.var("scale2", shape=(channels2, 1, 1))
        bias = relay.var("bias", shape=(channels2, 1, 1))
        y_before = before(x, w1, w2, scale1, scale2, bias)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 8)
def test_combine_parallel_conv2d_scale():
    """Testcase where the scales cannot be fused into the combined conv2d branch."""
    def before(x, w1, w2, scale1, scale2):
        args = [x, w1, w2, scale1, scale2]
        y1 = relay.nn.conv2d(x, w1)
        y1 = relay.multiply(y1, scale1)
        y2 = relay.nn.conv2d(x, w2)
        y2 = relay.multiply(y2, scale2)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def expected(x, w1, w2, scale1, scale2, channels1, channels2):
        args = [x, w1, w2, scale1, scale2]
        w = relay.concatenate((w1, w2), axis=0)
        y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
        y1 = relay.multiply(y1, scale1)
        y2 = relay.multiply(y2, scale2)
        y = relay.Tuple((y1, y2))
        return relay.Function(args, y)

    def check(x_shape, channels1, channels2):
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        scale1 = relay.var("scale1", shape=(1,))
        scale2 = relay.var("scale2", shape=(1,))
        y_before = before(x, w1, w2, scale1, scale2)
        y = relay.ir_pass.infer_type(y_before)
        y = relay.ir_pass.combine_parallel_conv2d(y)
        y = relay.ir_pass.infer_type(y)
        y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
        y_expected = relay.ir_pass.infer_type(y_expected)
        assert relay.ir_pass.alpha_equal(y, y_expected)

    check((1, 4, 16, 16), 4, 8)
if __name__ == "__main__":
    test_combine_parallel_conv2d()
    test_combine_parallel_conv2d_scale_relu()
    test_combine_parallel_conv2d_scale()
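A small numpy sketch (not part of the test suite) of why the rewrite checked above is valid for the 1x1 convolutions used in these tests: concatenating the weights along the output-channel axis and slicing the combined output on that axis reproduces the separate convolutions.

import numpy as np

def conv1x1(x, w):
    # x: (N, C, H, W), w: (O, C, 1, 1) -> (N, O, H, W); a 1x1 conv is a matmul over channels
    return np.einsum("nchw,oc->nohw", x, w[:, :, 0, 0])

x = np.random.randn(1, 4, 8, 8)
w1 = np.random.randn(4, 4, 1, 1)
w2 = np.random.randn(8, 4, 1, 1)

y = conv1x1(x, np.concatenate((w1, w2), axis=0))
np.testing.assert_allclose(y[:, :4], conv1x1(x, w1), rtol=1e-6)
np.testing.assert_allclose(y[:, 4:], conv1x1(x, w2), rtol=1e-6)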