Commit 1ed28aeb by masahi, committed by Tianqi Chen

[NNVM] Enhance operator fusion for more element wise patterns (#1548)

parent 0241fdc5
@@ -161,6 +161,103 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
}
}
  /*
    The above algorithm will not fuse a node whose output is fed to more than one
    child node. This is because, in general, it does not make sense to fuse multiple
    children branches with their parent, as in the following example.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          |    |    |
          |    |    |

    However, when all children branches meet at a certain node, there is a possibility for
    further operator fusion. For example, all nodes in the following subgraph can be fused
    into a single node, if the three 'in-between' nodes and the bottom node are all element-wise
    operations.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |

    This pattern is not uncommon. For example, it arises when a conv2d op is followed by an
    exponential linear unit (ELU). If bias add and batch normalization are also present, they
    can be fused as well.

    In fact, the fusion algorithm above already fuses the three in-between nodes and the
    element-wise add node in the figure above. The code below additionally fuses the conv2d
    node with its already-fused children nodes. The following patterns are supported:

    * Any number of child nodes from the top node
    * The path from the top node to the bottom node can contain any number of element-wise ops.

    The only restriction is that in-between nodes cannot have more than one child.

    The algorithm below works as follows:

    1. Check if all children nodes are fused into a single op by the existing fusion algorithm.
    2. Fuse the parent node into the children nodes, and update its group id to be the
       children's group id.
    3. If the parent node originally belongs to another group (for example, conv + batch norm),
       propagate the new group id to the grandparent and upward.
  */
  if (opt_level >= 1) {
    std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      CHECK_NE(group_vec[nid], -1);
      node_ids_per_group[group_vec[nid]].push_back(nid);
      if (inode.inputs.size() != 1) continue;
      const uint32_t parent_nid = inode.inputs[0].node_id;
      // if parent node has more than one child, record each child's group id.
      if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
    }

    std::vector<int> new_group_id(idx.num_nodes(), -1);
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      if (new_group_id[group_vec[nid]] != -1) {
        // propagate new group id from child
        group_vec[nid] = new_group_id[group_vec[nid]];
      }
      TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
      if (pt == kOpaque) continue;
      const auto& group_ids = children_group_ids[nid];
      if (group_ids.size() <= 1) continue;
      const uint32_t child_group_id = group_ids[0];
      const auto& children_node_ids = node_ids_per_group[child_group_id];
      auto is_same_group_id = [child_group_id](uint32_t id) {
        return id == child_group_id;
      };
      auto is_fusible_pattern = [&idx](uint32_t child_nid) {
        TOpPattern child_pt = op_pattern.get(idx[child_nid].source->op(), kOpaque);
        return child_pt <= kBroadcast;
      };
      // fuse this node with children if
      // all children belong to the same group and
      // all nodes in the group are element wise or broadcast op.
      const bool can_be_fused =
        std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id) &&
        std::all_of(children_node_ids.begin(), children_node_ids.end(), is_fusible_pattern);
      if (can_be_fused) {
        new_group_id[group_vec[nid]] = child_group_id;
        group_vec[nid] = child_group_id;
        for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
          pattern_vec[nid2] = pattern_vec[nid];
          master_vec[nid2] = master_vec[nid];
        }
      }
    }
  }

  g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
  g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
  g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
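For reference, a minimal Python sketch (not part of this commit; the op choices and shapes are illustrative) of the diamond-shaped subgraph described in the comment above: a conv2d whose output feeds several element-wise branches that later meet at an element-wise add. With opt_level >= 1, the pass above fuses this entire subgraph into a single operator.

import nnvm.symbol as sym

# hypothetical diamond pattern: conv2d -> three element-wise branches -> elemwise add
data = sym.Variable("data")
conv = sym.conv2d(data=data, kernel_size=(3, 3), channels=16,
                  padding=(1, 1), use_bias=False)
branch_a = sym.exp(conv)
branch_b = sym.sigmoid(conv)
branch_c = sym.tanh(conv)
# the branches rejoin here; with the new pass, everything from conv2d down can fuse into one op
out = sym.elemwise_add(sym.elemwise_add(branch_a, branch_b), branch_c)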
@@ -5,7 +5,7 @@ import topi.testing
from tvm.contrib import graph_runtime
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
from nnvm.testing import ctx_list
from nnvm.testing import ctx_list, utils
def test_ewise_injective():
    x = sym.Variable("x")
@@ -77,7 +77,49 @@ def test_injective_reduce_injective():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(sym, target, shape={"data": data.shape}, params=params)
    module = graph_runtime.create(graph, lib, ctx)
    module.set_input(**params)
    module.set_input("data", data)
    module.run()
    out = module.get_output(0, tvm.nd.empty(out_shape))
    return out.asnumpy(), graph
def test_fuse_conv2d_elu():
    def elu(data):
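        # equivalent to ELU with alpha = 0.5: x for x > 0, 0.5 * (exp(x) - 1) otherwise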
        return -0.5 * sym.relu(1 - sym.exp(data)) + sym.relu(data)

    def get_sym(out_channel):
        data = sym.Variable(name="data")
        data = sym.conv2d(data=data, kernel_size=(3,3), channels=out_channel, padding=(1, 1),
                          layout="NCHW", kernel_layout="OIHW", use_bias=True)
        data = sym.batch_norm(data)
        data = elu(data)
        return data

    in_channel = 8
    out_channel = 16
    size = 64
    dshape = (1, in_channel, size, size)
    oshape = (1, out_channel, size, size)
    data = np.random.uniform(-1, 1, dshape).astype(np.float32)

    for target, ctx in ctx_list():
        sym1 = get_sym(out_channel)
        sym2 = get_sym(out_channel)
        _, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
        _, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
        output1, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
        output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
        np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
        # data, conv weight, bias, batch norm gamma, batch norm beta, conv op
        assert g1.index.num_nodes == 6
if __name__ == "__main__":
    test_injective_reduce_injective()
    test_ewise_injective()
    test_conv_ewise_injective()
    test_fuse_conv2d_elu()
@@ -39,11 +39,10 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
def schedule_conv2d_nchw_arm_cpu(cfg, outs):
    """TOPI schedule callback"""
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def _callback(op):
        # schedule conv2d
        if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
        if 'spatial_conv_output' in op.tag:
            output = op.output(0)
            conv = op.input_tensors[0]
@@ -65,8 +64,6 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
            output = op.output(0)
            _schedule_winograd(cfg, s, output, outs[0])

        scheduled_ops.append(op)
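    # note: the scheduled_ops bookkeeping above is removed because traverse_inline
    # now deduplicates visits itself via its internal visited set (see the next hunk)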
    traverse_inline(s, outs[0].op, _callback)
    return s
@@ -5,26 +5,34 @@ import tvm
from . import tag

def traverse_inline(s, op, callback):
def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Parameters
    ----------
    s: schedule
        The schedule

    op: Operation
    final_op: Operation
        The final output operator.

    callback: callable
        The callback function on each op
    """
    visited = set()

    def _traverse(op):
        if op in visited:
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
        for tensor in op.input_tensors:
            if tensor.op.input_tensors:
                traverse_inline(s, tensor.op, callback)
                _traverse(tensor.op)
        callback(op)

    _traverse(final_op)
def prod(x):
    """Get the product of every item in the tuple.
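As a usage note (not part of this commit), the sketch below shows one way the revised traverse_inline can drive a schedule; the placeholder shape, the topi ops, and the import path (assumed to be topi.util, the module shown in the hunk above) are illustrative.

import tvm
import topi
from topi.util import traverse_inline  # import path assumed from the hunk above

# toy pipeline: an element-wise op feeding the final output op
A = tvm.placeholder((1, 16, 32, 32), name="A")
B = topi.exp(A)        # injective and not an output, so it gets compute_inline'd
C = topi.nn.relu(B)    # final op; it is in s.outputs and is left intact

s = tvm.create_schedule(C.op)

def _callback(op):
    # called once per visited op, in post-order; schedule-specific logic goes here
    if op is C.op:
        s[C].parallel(C.op.axis[0])

traverse_inline(s, C.op, _callback)
f = tvm.build(s, [A, C], "llvm")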