Commit a53d8d01 by Tianqi Chen

[PASS] Enhance scale fold axis (#424)

parent 89c124bc
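This commit extends the FoldScaleAxis optimization so that a per-input-channel scale multiplied into a convolution's input can be folded into the conv2d weight, alongside the per-output-channel scale the pass already handled. Below is a minimal NumPy sketch of the algebraic identity the rewrite relies on; the naive conv2d helper and all names in it are illustrative assumptions, not NNVM code:

import numpy as np

def conv2d(x, w, b):
    # Naive NCHW conv2d with a 3x3 kernel and padding 1 (illustrative only).
    n, ci, h, wd = x.shape
    co, _, kh, kw = w.shape
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    y = np.zeros((n, co, h, wd), dtype=x.dtype)
    for i in range(h):
        for j in range(wd):
            patch = xp[:, :, i:i + kh, j:j + kw]  # (n, ci, kh, kw)
            y[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3])) + b
    return y

rng = np.random.RandomState(0)
x = rng.randn(2, 4, 10, 10).astype("float32")
w = rng.randn(2, 4, 3, 3).astype("float32")
b = rng.randn(2).astype("float32")
s_in = rng.randn(4).astype("float32")    # per-input-channel scale
s_out = rng.randn(2).astype("float32")   # per-output-channel scale

# Unfolded form: scale the input channels, convolve, scale the output channels.
y1 = conv2d(x * s_in[None, :, None, None], w, b) * s_out[None, :, None, None]
# Folded form: both scales move into the weight, the output scale into the bias.
w_fold = w * s_out[:, None, None, None] * s_in[None, :, None, None]
y2 = conv2d(x, w_fold, b * s_out)
np.testing.assert_allclose(y1, y2, rtol=1e-4, atol=1e-4)

The folded graph multiplies the (usually constant) weight once at compile time instead of rescaling every activation at runtime. The first hunk below touches the storage-reuse condition in the memory planner; the remainder extends the fold-axis unit tests.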
@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
       if (taken[kv.first] == false &&
           sid_out == GraphAllocator::kBadStorageID &&
           sid_in >= 0 &&
-          (storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) &&
+          ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) &&
           entry_ref_count[eid_out] > 0 &&
           shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
           dtype_vec[eid_out] == dtype_vec[eid_in]) {
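Note that this C++ hunk only makes the grouping explicit: `&&` already binds tighter than `||` in C++, so the added parentheses do not change the condition, they just spell out the intent (and silence compilers that warn about mixing `&&` and `||` without parentheses). Python's `and`/`or` follow the same precedence, as this two-line check illustrates:

a, b, c = True, False, True
assert (a and b or c) == ((a and b) or c)

The rest of the diff is the Python unit-test file for the pass: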
"""Unittest cases for fold_axis""" """Unittest cases for fold_axis"""
import nnvm import nnvm
import nnvm.testing.resnet
import numpy as np
from nnvm import symbol as sym from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr from nnvm.compiler import graph_util, graph_attr
def test_fold_axis_conv(): def test_fold_axis_conv():
def before(x, conv_weight, conv_bias, scale, channels): def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
y = sym.conv2d(x, conv_weight, conv_bias, y = sym.conv2d(x, conv_weight, conv_bias,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
name="conv") name="conv")
y = sym.relu(y) y = sym.relu(y)
y = y * sym.expand_dims(scale, axis=1, num_newaxis=2) y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
return y return y
def expected(x, conv_weight, conv_bias, scale, channels): def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3) conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
conv_bias = conv_bias * scale conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
conv_bias = conv_bias * out_scale
y = sym.conv2d(x, y = sym.conv2d(x,
conv_weight, conv_weight,
conv_bias, conv_bias,
@@ -32,10 +36,11 @@ def test_fold_axis_conv():
         x = sym.Variable("x") + 1
         weight = sym.Variable("weight")
         bias = sym.Variable("bias")
-        scale = sym.Variable("scale")
-        y1 = before(x, weight, bias, scale, channels)
-        y2 = expected(x, weight, bias, scale, channels)
-        ishape = {"x": shape, "scale": (channels,)}
+        in_scale = sym.Variable("in_scale")
+        out_scale = sym.Variable("out_scale")
+        y1 = before(x, weight, bias, in_scale, out_scale, channels)
+        y2 = expected(x, weight, bias, in_scale, out_scale, channels)
+        ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
         g1 = nnvm.graph.create(y1)
         g2 = nnvm.graph.create(y2)
         graph_attr.set_shape_inputs(g1, ishape)
@@ -45,5 +50,61 @@ def test_fold_axis_conv():
     check((2, 4, 10, 10), 2)
 
 
+def test_fold_fail():
+    def before(x, scale, channels):
+        y = sym.conv2d(x,
+                       channels=channels,
+                       kernel_size=(3, 3),
+                       padding=(1, 1),
+                       name="conv")
+        y = y * sym.expand_dims(scale, axis=1, num_newaxis=1)
+        return y
+
+    # the pass should leave this graph untouched
+    def check(shape, channels):
+        x = sym.Variable("x")
+        bias = sym.Variable("bias")
+        scale = sym.Variable("scale")
+        y1 = before(x, scale, channels)
+        ishape = {"x": shape, "scale": (channels,), "bias": (channels,)}
+        g1 = nnvm.graph.create(y1)
+        graph_attr.set_shape_inputs(g1, ishape)
+        g2 = g1.apply("InferShape").apply("FoldScaleAxis")
+        # assert the graph is unchanged, i.e. nothing was folded
+        graph_util.check_graph_equal(g1, g2)
+    check((2, 10, 10, 10), 10)
+
+
+def test_fold_resnet():
+    batch_size = 1
+    num_classes = 1000
+    image_shape = (3, 224, 224)
+    data_shape = (batch_size,) + image_shape
+    net, params = nnvm.testing.resnet.get_workload(
+        batch_size=batch_size, image_shape=image_shape)
+    ishape = {"data": data_shape}
+    graph = nnvm.graph.create(net)
+    data = np.random.uniform(size=data_shape).astype("float32")
+    # initial pass to do shape/type inference
+    shape, _ = graph_util.infer_shape(graph, **ishape)
+    ishape.update(zip(graph.index.input_names, shape))
+
+    def run_prune(graph, params, opt_level):
+        # apply optimizations at the requested opt_level
+        with nnvm.compiler.build_config(opt_level=opt_level):
+            graph = nnvm.compiler.optimize(graph, ishape)
+        graph, params = nnvm.compiler.build_module.precompute_prune(graph, params)
+        params["data"] = data
+        return nnvm.compiler.build_module._run_graph(graph, params)
+
+    x = run_prune(graph, params, 0)
+    y = run_prune(graph, params, 3)
+    np.testing.assert_allclose(y[0].asnumpy(), x[0].asnumpy())
+
+
 if __name__ == "__main__":
+    test_fold_resnet()
     test_fold_axis_conv()
+    test_fold_fail()
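Two remarks on the new tests. In test_fold_fail, the scale is expanded with num_newaxis=1, so it has shape (channels, 1) and broadcasts against the height and width axes of the NCHW output rather than against the channel axis; FoldScaleAxis therefore cannot fold it, and the test asserts the graph is left unchanged. A small NumPy check of that broadcasting behavior (names here are illustrative):

import numpy as np

y = np.ones((2, 10, 10, 10))          # NCHW output of the conv
s = np.arange(10.0).reshape(10, 1)    # expand_dims(scale, axis=1, num_newaxis=1)
# Broadcasting aligns s with the trailing (H, W) axes, not the channel axis:
# (y * s)[n, c, h, w] == y[n, c, h, w] * s[h, 0]
assert np.array_equal((y * s)[0, 0, :, 0], np.arange(10.0))

test_fold_resnet is an end-to-end check: it runs ResNet once at opt_level 0 (no folding) and once at opt_level 3 (FoldScaleAxis enabled, folding scales such as simplified batch-norm multipliers into conv weights), prunes the precomputable subgraph in both cases, and asserts the two outputs match.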