Commit 2b3d2e21 by Tianqi Chen

[PASS] Improve GraphFuse to include five patterns (#26)

parent 2e9b6b99
NNVM Core Primitives
====================
NNVM Core Tensor Operators
==========================
**Level 1: Basic Ops**
**Level 1: Basic Operators**
This level enables fully connected multi-layer perceptron.
.. autosummary::
:nosignatures:
......@@ -12,12 +13,14 @@ NNVM Core Primitives
nnvm.symbol.sigmoid
nnvm.symbol.exp
nnvm.symbol.log
nnvm.symbol.sqrt
nnvm.symbol.elemwise_add
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
nnvm.symbol.split
nnvm.symbol.dropout
nnvm.symbol.batch_norm
......@@ -27,6 +30,8 @@ NNVM Core Primitives
**Level 2: Convolutions**
This level enables typical convnet models.
.. autosummary::
:nosignatures:
......@@ -78,12 +83,14 @@ NNVM Core Primitives
.. autofunction:: nnvm.symbol.sigmoid
.. autofunction:: nnvm.symbol.exp
.. autofunction:: nnvm.symbol.log
.. autofunction:: nnvm.symbol.sqrt
.. autofunction:: nnvm.symbol.elemwise_add
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
.. autofunction:: nnvm.symbol.split
.. autofunction:: nnvm.symbol.dropout
.. autofunction:: nnvm.symbol.batch_norm
......
......@@ -25,16 +25,23 @@ using ::tvm::Tensor;
using ::tvm::Schedule;
/*! \brief operator pattern used in graph fusion */
enum OpPatternKind : int {
enum OpPatternKind {
// Elementwise operation
kElemWise = 0,
// Broadcast operation
// Broadcasting operator, can always map output axis to the input in order.
// for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
// Note that the axis need to be in order so transpose is not a bcast operator.
kBroadcast = 1,
// Complex operation, can fuse bcast in input/outputs
// Injective operator, can always injectively map output axis to a single input axis.
// All injective operators can still be safely fused into injective and reduction operators.
kInjective = 2,
// Commutative reduction operator.
kCommReduce = 3,
// Complex operation, can still fuse elemwise operations into its output.
// but cannot chain another complex op
kComplex = 2,
// Extern operation, cannot fuse anything.
kExtern = 3
kOutEWiseFusable = 4,
// Opaque operation, cannot fuse anything.
kOpaque = 8
};
/*! \brief the operator pattern */
......
......@@ -3,12 +3,24 @@
import tvm
class OpPattern(object):
ELEM_WISE = 0
"""Operator generic patterns
See Also
--------
top.tag : Contains explanation of the tag type.
"""
# Elementwise operator
ELEMWISE = 0
# Broadcast operator
BROADCAST = 1
# Complex means we can fuse elemwise to it
COMPLEX = 2
# Extern means the op is not fusable
EXTERN = 3
# Injective mapping
INJECTIVE = 2
# Commutative reduction
COMM_REDUCE = 3
# Complex op, can still fuse ewise into it
OUT_ELEMWISE_FUSABLE = 4
# Not fusable opaque op
OPAQUE = 8
_register_compute = tvm.get_global_func("nnvm._register_compute")
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
......
......@@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _):
return topi.nn.relu(inputs[0])
reg.register_schedule("relu", _fschedule_broadcast)
reg.register_pattern("relu", OpPattern.ELEM_WISE)
reg.register_pattern("relu", OpPattern.ELEMWISE)
# leaky_relu
@reg.register_compute("leaky_relu")
def compute_leaky_relu(attrs, inputs, _):
    """Compute definition of leaky_relu.

    Fixes the copy-pasted name `compute_relu`, which shadowed the relu
    compute function defined earlier in this module. Registration is
    keyed by the "leaky_relu" string, so behavior is unchanged.
    """
    return topi.nn.leaky_relu(inputs[0])

reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# flatten
@reg.register_compute("flatten")
......@@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _):
return topi.nn.flatten(inputs[0])
reg.register_schedule("flatten", _fschedule_broadcast)
reg.register_pattern("flatten", OpPattern.COMPLEX)
reg.register_pattern("flatten", OpPattern.INJECTIVE)
# softmax
......@@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target):
return tvm.create_schedule([x.op for x in outs])
# Mark softmax as extern as we do not fuse it in all cases
reg.register_pattern("softmax", OpPattern.EXTERN)
reg.register_pattern("softmax", OpPattern.OPAQUE)
# dense
......@@ -67,7 +75,7 @@ def schedule_dense(_, outs, target):
return tvm.create_schedule([x.op for x in outs])
# register extern for now, change me when fusion is enabled.
reg.register_pattern("dense", OpPattern.EXTERN)
reg.register_pattern("dense", OpPattern.OPAQUE)
# conv
......@@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target):
# naive schedule
return tvm.create_schedule([x.op for x in outs])
reg.register_pattern("conv2d", OpPattern.COMPLEX)
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......@@ -8,13 +8,15 @@ import topi.cuda
from ..compiler import registry as reg
from ..compiler import OpPattern
def _schedule_broadcast(_, outs, target):
def _schedule_injective(_, outs, target):
"""Generic schedule for binary bcast"""
if target == "cuda":
return topi.cuda.schedule_elemwise(outs)
return topi.cuda.schedule_injective(outs)
assert target.startswith("llvm")
s = tvm.create_schedule([x.op for x in outs])
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
def _compute_binary_scalar(f):
......@@ -42,89 +44,91 @@ def _compute_binary(f):
return _compute
_fschedule_broadcast = tvm.convert(_schedule_broadcast)
_fschedule_injective = tvm.convert(_schedule_injective)
_fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective
# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEM_WISE)
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
# exp
reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE)
reg.register_pattern("exp", OpPattern.ELEMWISE)
reg.register_schedule("exp", _fschedule_broadcast)
# sqrt
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
reg.register_pattern("sqrt", OpPattern.ELEMWISE)
reg.register_schedule("sqrt", _fschedule_broadcast)
# log
reg.register_compute("log", _compute_unary(topi.log))
reg.register_pattern("log", OpPattern.ELEM_WISE)
reg.register_pattern("log", OpPattern.ELEMWISE)
reg.register_schedule("log", _fschedule_broadcast)
# tanh
reg.register_compute("tanh", _compute_unary(topi.tanh))
reg.register_pattern("tanh", OpPattern.ELEM_WISE)
reg.register_pattern("tanh", OpPattern.ELEMWISE)
reg.register_schedule("tanh", _fschedule_broadcast)
# negative
reg.register_compute("negative", _compute_unary(topi.negative))
reg.register_pattern("negative", OpPattern.ELEM_WISE)
reg.register_pattern("negative", OpPattern.ELEMWISE)
reg.register_schedule("negative", _fschedule_broadcast)
# sigmoid
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
reg.register_pattern("sigmoid", OpPattern.ELEMWISE)
reg.register_schedule("sigmoid", _fschedule_broadcast)
# add_scalar
reg.register_compute("__add_scalar__",
_compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)
# sub_scalar
reg.register_compute("__sub_scalar__",
_compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
# rsub_scalar
reg.register_compute("__rsub_scalar__",
_compute_binary_scalar(lambda x, y: y - x))
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
# mul_scalar
reg.register_compute("__mul_scalar__",
_compute_binary_scalar(lambda x, y: x * y))
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
# div_scalar
reg.register_compute("__div_scalar__",
_compute_binary_scalar(lambda x, y: x / y))
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__div_scalar__", _fschedule_broadcast)
# rdiv_scalar
reg.register_compute("__rdiv_scalar__",
_compute_binary_scalar(lambda x, y: y / x))
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
# elemwise_add
......
......@@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info):
oshape = out_info[0].shape
x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
reg.register_pattern("reshape", OpPattern.COMPLEX)
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_broadcast)
......@@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
ref_count[e.node_id] += 2;
}
// Pattern for the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
// Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Master node id of fusion segment.
......@@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
if (inode.source->is_variable()) {
fuse_vec[nid] = FuseRule::kRealize; continue;
}
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);
TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);
if (pt <= kBroadcast) {
// Try to check if we can fuse to the master.
int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt != kElemWise) ewise = false;
if (ipt <= kBroadcast) {
if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else if (ipt == kComplex && chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
} else if (ipt == kOutEWiseFusable &&
chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
......@@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
master_vec[nid] = chosen_master;
if (chosen_master != -1) {
pt = kComplex;
pt = kOutEWiseFusable;
} else {
pt = ewise ? kElemWise : kBroadcast;
}
} else if (pt == kInjective || pt == kCommReduce) {
// fuse to the comm reduce or injective
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
TOpPattern ipt = pattern_vec[e.node_id];
if (ipt <= kInjective) {
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
fuse_vec[e.node_id] = FuseRule::kRealize;
}
}
}
if (pt == kCommReduce) {
master_vec[nid] = nid;
}
} else {
// realize
master_vec[nid] = nid;
for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) {
......@@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
}
// point to the group root id of each node
std::vector<int> group_vec(idx.num_nodes(), -1);
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
......
......@@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
......
import nnvm
import numpy as np
import tvm
import topi
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
from nnvm.testing.config import test_ctx_list
def test_ewise_injective():
    """Check that elementwise and injective ops fuse into one node.

    Builds `flatten(x * 2) + 1`; multiply (elemwise), flatten (injective)
    and add (elemwise) should all fuse, leaving exactly 2 graph nodes
    (the variable plus one fused op). Verifies numerical output against
    a NumPy reference on every available target.
    """
    x = sym.Variable("x")
    y = x * 2
    y = sym.flatten(y) + 1
    dshape = (10, 2, 3)
    shape_dict = {"x": dshape}
    dtype = "float32"
    # NOTE: removed dead `target = "llvm"` assignment that was immediately
    # shadowed by the loop variable below.
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        # variable + single fused op
        assert graph.index.num_nodes == 2
        m = nnvm.runtime.create(graph, lib, ctx)
        x_np = np.random.uniform(size=dshape).astype(dtype)
        m.run(x=x_np)
        # flatten reshapes (10, 2, 3) -> (10, 6)
        out = m.get_output(0, tvm.nd.empty((10, 6)))
        np.testing.assert_allclose(
            out.asnumpy(), x_np.reshape(out.shape) * 2 + 1,
            atol=1e-5, rtol=1e-5)
def test_conv_ewise_injective():
    """Check fusion of elementwise/injective ops into a conv2d output.

    conv2d is OUT_ELEMWISE_FUSABLE, so the bias-add, the `+ 1`, the
    flatten and the final `+ 1` should fuse into its output, giving a
    5-node compiled graph. Output is compared against a Python
    depthwise-conv reference (groups == channels == 32).
    """
    x = sym.Variable("x")
    y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
                   name="y", padding=(1, 1))
    y = sym.flatten(y + 1) + 1
    dtype = "float32"
    dshape = (1, 32, 18, 18)
    kshape = (32, 1, 3, 3)
    oshape = (1, 32 * 18 * 18)
    shape_dict = {"x": dshape}
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
        m = nnvm.runtime.create(graph, lib, ctx)
        # 3 variables (x, y_weight, y_bias) + fused ops
        assert graph.index.num_nodes == 5
        # set input
        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
        m.run(x=data, y_weight=kernel, y_bias=bias)
        # get output
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
        # reference: depthwise conv + bias + 1, then flatten, then + 1
        c_np = topi.testing.depthwise_conv2d_python_nchw(
            data.asnumpy(), kernel.asnumpy(), (1, 1), 'SAME')
        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1) + 1
        c_np = c_np.reshape(c_np.shape[0], np.prod(c_np.shape[1:])) + 1
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
if __name__ == "__main__":
test_ewise_injective()
test_conv_ewise_injective()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment