Commit a53d8d01 by Tianqi Chen

[PASS] Enhance scale fold axis (#424)

parent 89c124bc
...@@ -18,12 +18,10 @@ namespace compiler { ...@@ -18,12 +18,10 @@ namespace compiler {
enum FoldScaleKind { enum FoldScaleKind {
// No folding is applied // No folding is applied
kNone, kNone,
// The folding decision is pending // The folding decision is pending, we can fold on a state.
kPending, kPending,
// The original operator that contains the scale. // The original operator that contains the scale.
kProvider, kProvider,
// Pass through the scale to parent/child to the first axis.
kPassTroughFirst,
// The final consumer of axis scale using multiply // The final consumer of axis scale using multiply
// Likely be a conv or dense operator. // Likely be a conv or dense operator.
kMulConsumer, kMulConsumer,
...@@ -31,21 +29,23 @@ enum FoldScaleKind { ...@@ -31,21 +29,23 @@ enum FoldScaleKind {
kDivConsumer kDivConsumer
}; };
// Input fold information struct FoldChainInfo {
struct FoldScaleInput {
uint32_t index;
int axis;
};
// The entry of folding chains on which
// we should perform folding on
struct FoldChainEntry {
// Entry kind // Entry kind
FoldScaleKind kind{kNone}; FoldScaleKind kind{kNone};
// The output axis to be folded // The output axis to be folded
int axis{0}; int axis{0};
// Source node in the fold chain // Source node in the fold chain
int source{0}; int source{0};
};
// The entry of folding chains on which
// we should perform folding on
struct FoldChainEntry {
// Fold information
FoldChainInfo info;
// Number of outgoing fork count
// in forward propagation.
int fork_count{0};
// Following field only used by provider. // Following field only used by provider.
// The input index // The input index
int fold_input_index{1}; int fold_input_index{1};
...@@ -55,12 +55,26 @@ struct FoldChainEntry { ...@@ -55,12 +55,26 @@ struct FoldChainEntry {
// Try to pass axis scaling to backward, // Try to pass axis scaling to backward,
// Given that we know the status of current fold axis. // Given that we know the status of current fold axis.
// return whether the forward signal is consumed.
using FScaleAxisBackward = std::function< using FScaleAxisBackward = std::function<
FoldScaleKind(const NodeAttrs& attrs, bool(const NodeAttrs& attrs,
int axis, const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
const FoldChainInfo& out_info,
std::vector<FoldChainInfo>* in_info)>;
// Try to pass axis scaling to forward,
// Given that we know the status of one of its inputs to be pending
// also update other input info
// return whether the forward signal is consumed.
using FScaleAxisForward = std::function<
bool(const NodeAttrs& attrs,
const std::vector<TShape>& in_shape, const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape, const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis)>; std::vector<FoldChainInfo>* in_info,
FoldChainInfo* out_info)>;
// Detect if there is a scaling axis happening // Detect if there is a scaling axis happening
bool DetectScaleAxis(const IndexedGraph& idx, bool DetectScaleAxis(const IndexedGraph& idx,
...@@ -99,15 +113,19 @@ bool DetectScaleAxis(const IndexedGraph& idx, ...@@ -99,15 +113,19 @@ bool DetectScaleAxis(const IndexedGraph& idx,
} else { } else {
return false; return false;
} }
e.axis = axis.first; e.info.axis = axis.first;
e.kind = kPending; e.info.kind = kPending;
e.source = nid; e.info.source = nid;
e.fork_count = 1;
// In the backward message passing
// We need to eagerly pass it to the input
// In the forward message passing
// we will "pull" the message from input.
if (!is_forward) { if (!is_forward) {
// pass message to another input
FoldChainEntry& enext = (*chain)[b.node_id]; FoldChainEntry& enext = (*chain)[b.node_id];
enext.axis = e.axis; enext.info.axis = e.info.axis;
enext.kind = kPending; enext.info.kind = kPending;
enext.source = nid; enext.info.source = nid;
} }
return true; return true;
} }
...@@ -119,12 +137,16 @@ Graph FoldScaleAxis(Graph src) { ...@@ -119,12 +137,16 @@ Graph FoldScaleAxis(Graph src) {
// Operator pattern // Operator pattern
static auto& fbackward = static auto& fbackward =
nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward"); nnvm::Op::GetAttr<FScaleAxisBackward>("FScaleAxisBackward");
static auto& fforward =
nnvm::Op::GetAttr<FScaleAxisForward>("FScaleAxisForward");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape"); const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
std::vector<uint32_t> ref_count = GetNodeRefCounts(idx); std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
std::vector<FoldChainEntry> bwd_chain(idx.num_nodes()); std::vector<FoldChainEntry> bwd_chain(idx.num_nodes());
std::vector<FoldChainEntry> fwd_chain(idx.num_nodes());
// shape hint for the inference. // shape hint for the inference.
std::vector<TShape> in_shape, out_shape; std::vector<TShape> in_shape, out_shape;
// perform backward folding. // perform backward folding.
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
uint32_t nid = i - 1; uint32_t nid = i - 1;
...@@ -132,9 +154,10 @@ Graph FoldScaleAxis(Graph src) { ...@@ -132,9 +154,10 @@ Graph FoldScaleAxis(Graph src) {
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) continue;
if (DetectScaleAxis(idx, nid, shape_vec, if (DetectScaleAxis(idx, nid, shape_vec,
ref_count, false, &bwd_chain)) continue; ref_count, false, &bwd_chain)) continue;
if (bwd_chain[nid].kind != kPending) continue; if (bwd_chain[nid].info.kind != kPending) continue;
// if referred by multiple node, cannot do propagation
if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) { if (ref_count[nid] != 1 || !fbackward.count(inode.source->op())) {
bwd_chain[nid].kind = kNone; continue; bwd_chain[nid].info.kind = kNone; continue;
} }
// get input shape and output shape. // get input shape and output shape.
in_shape.clear(); out_shape.clear(); in_shape.clear(); out_shape.clear();
...@@ -144,58 +167,151 @@ Graph FoldScaleAxis(Graph src) { ...@@ -144,58 +167,151 @@ Graph FoldScaleAxis(Graph src) {
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
out_shape.push_back(shape_vec[idx.entry_id(nid, i)]); out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
} }
std::vector<std::pair<uint32_t, int> > in_axis; std::vector<FoldChainInfo> in_info(in_shape.size(), FoldChainInfo());
FoldScaleKind kind = bool consumed = fbackward[inode.source->op()](
fbackward[inode.source->op()]( inode.source->attrs,
inode.source->attrs, bwd_chain[nid].axis, in_shape,
in_shape, out_shape, &in_axis); out_shape,
bwd_chain[nid].kind = kind; bwd_chain[nid].info,
if (kind == kNone) continue; &in_info);
CHECK_GE(in_axis.size(), 1U); CHECK_EQ(in_info.size(), in_shape.size());
CHECK(kind == kPassTroughFirst || kind == kMulConsumer);
// propagate back. // propagate back.
bool can_prop = true; bool can_prop = true;
for (size_t i = 0; i < in_axis.size(); ++i) { for (size_t i = 0; i < in_info.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[0].first]; const IndexedGraph::NodeEntry& e = inode.inputs[i];
if (ref_count[e.node_id] != 1 || if (ref_count[e.node_id] != 1 ||
idx[e.node_id].source->num_outputs() != 1) { idx[e.node_id].source->num_outputs() != 1) {
can_prop = false; break; can_prop = false; break;
} }
} }
if (!can_prop) continue; if (!can_prop) continue;
for (size_t i = 0; i < in_axis.size(); ++i) { for (size_t i = 0; i < in_info.size(); ++i) {
const IndexedGraph::NodeEntry& e = inode.inputs[in_axis[i].first]; const IndexedGraph::NodeEntry& e = inode.inputs[i];
if (kind == kPassTroughFirst && i == 0) { bwd_chain[e.node_id].info = in_info[i];
bwd_chain[e.node_id].kind = kPending; }
// mark consumed by making the source as provider.
if (consumed) {
bwd_chain[bwd_chain[nid].info.source].info.kind = kProvider;
}
}
// perform forward folding.
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
// skip scales that are already folded in backward.
if (bwd_chain[nid].info.kind == kProvider) continue;
if (DetectScaleAxis(idx, nid, shape_vec,
ref_count, true, &fwd_chain)) continue;
if (inode.source->num_outputs() != 1) continue;
// Do state update
// get input shape and output shape.
std::vector<FoldChainInfo> in_info;
FoldChainInfo out_info;
int num_inpending = 0;
in_shape.clear(); out_shape.clear();
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
in_shape.push_back(shape_vec[idx.entry_id(e)]);
// input information
in_info.push_back(fwd_chain[e.node_id].info);
if (fwd_chain[e.node_id].info.kind == kPending) {
++num_inpending;
}
}
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
out_shape.push_back(shape_vec[idx.entry_id(nid, i)]);
}
if (num_inpending != 1 ||
!fforward.count(inode.source->op())) continue;
bool consumed = fforward[inode.source->op()](
inode.source->attrs,
in_shape,
out_shape,
&in_info,
&out_info);
// update input info
for (size_t i = 0; i < in_info.size(); ++i) {
fwd_chain[inode.inputs[i].node_id].info = in_info[i];
}
if (consumed) {
fwd_chain[nid].info = out_info;
for (size_t i = 0; i < in_info.size(); ++i) {
if (in_info[i].kind == kPending) {
if (--fwd_chain[in_info[i].source].fork_count == 0) {
fwd_chain[in_info[i].source].info.kind = kProvider;
}
}
}
} else { } else {
bwd_chain[nid].kind = kNone; // can propagate condition
bwd_chain[e.node_id].kind = kMulConsumer; if (inode.source->num_outputs() == 1) {
fwd_chain[nid].info = out_info;
if (out_info.kind == kPending) {
// When there is multiple reference to input
// every path have to be consumed
fwd_chain[out_info.source].fork_count += ref_count[nid] - 1;
} }
bwd_chain[e.node_id].axis = in_axis[i].second;
bwd_chain[e.node_id].source = bwd_chain[nid].source;
} }
if (kind == kMulConsumer) {
bwd_chain[bwd_chain[nid].source].kind = kProvider;
} }
} }
auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) { auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
NodeEntry rvalue = NodeEntry{n, 0, 0};
{
// Backward chain
const FoldChainEntry& e = bwd_chain[nid]; const FoldChainEntry& e = bwd_chain[nid];
if (e.kind == kMulConsumer && bwd_chain[e.source].kind == kProvider) { if (e.info.kind == kMulConsumer &&
const FoldChainEntry& se = bwd_chain[e.source]; bwd_chain[e.info.source].info.kind == kProvider) {
const FoldChainEntry& se = bwd_chain[e.info.source];
CHECK_EQ(n->num_outputs(), 1); CHECK_EQ(n->num_outputs(), 1);
NodeEntry scale = ExpandBiasToMatchAxis( NodeEntry scale = ExpandBiasToMatchAxis(
se.scale_entry, se.scale_entry,
shape_vec[idx.entry_id(nid, 0)].ndim(), shape_vec[idx.entry_id(nid, 0)].ndim(),
shape_vec[idx.entry_id(se.scale_entry)].ndim(), shape_vec[idx.entry_id(se.scale_entry)].ndim(),
e.axis); e.info.axis);
*ret = {MakeNode("broadcast_mul", n->attrs.name + "_sc", rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
{NodeEntry{n, 0, 0}, scale})}; {rvalue, scale});
return true; } else if (e.info.kind == kProvider) {
} else if (e.kind == kProvider) { rvalue = n->inputs[e.fold_input_index];
*ret = {n->inputs[e.fold_input_index]}; }
return true; }
} else { // Note that the value might get transformed twice if it
// folds value from both fwd and backward chain.
{
// forward chain
const FoldChainEntry& e = fwd_chain[nid];
if (e.info.kind == kMulConsumer &&
fwd_chain[e.info.source].info.kind == kProvider) {
const FoldChainEntry& se = fwd_chain[e.info.source];
CHECK_EQ(n->num_outputs(), 1);
NodeEntry scale = ExpandBiasToMatchAxis(
se.scale_entry,
shape_vec[idx.entry_id(nid, 0)].ndim(),
shape_vec[idx.entry_id(se.scale_entry)].ndim(),
e.info.axis);
rvalue = MakeNode("broadcast_mul", n->attrs.name + "_sc",
{rvalue, scale});
} else if (e.info.kind == kDivConsumer &&
fwd_chain[e.info.source].info.kind == kProvider) {
const FoldChainEntry& se = fwd_chain[e.info.source];
CHECK_EQ(n->num_outputs(), 1);
NodeEntry scale = ExpandBiasToMatchAxis(
se.scale_entry,
shape_vec[idx.entry_id(nid, 0)].ndim(),
shape_vec[idx.entry_id(se.scale_entry)].ndim(),
e.info.axis);
rvalue = MakeNode("broadcast_div", n->attrs.name + "_sc",
{rvalue, scale});
} else if (e.info.kind == kProvider) {
rvalue = n->inputs[e.fold_input_index];
}
}
if (rvalue.node == n) {
return false; return false;
} else {
*ret = {rvalue};
return true;
} }
}; };
return GraphTransform(src, transform); return GraphTransform(src, transform);
...@@ -205,14 +321,24 @@ NNVM_REGISTER_PASS(FoldScaleAxis) ...@@ -205,14 +321,24 @@ NNVM_REGISTER_PASS(FoldScaleAxis)
.set_body(FoldScaleAxis); .set_body(FoldScaleAxis);
// property registration. // property registration.
FoldScaleKind ReluScaleAxisBackward( bool ReluScaleAxisBackward(
const NodeAttrs& attrs,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
const FoldChainInfo& out_info,
std::vector<FoldChainInfo>* in_axis) {
(*in_axis)[0] = out_info;
return false;
}
bool ReluScaleAxisForward(
const NodeAttrs& attrs, const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape, const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape, const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) { std::vector<FoldChainInfo>* in_info,
in_axis->emplace_back(0, axis); FoldChainInfo* out_info) {
return kPassTroughFirst; *out_info = (*in_info)[0];
return false;
} }
NNVM_REGISTER_OP(relu) NNVM_REGISTER_OP(relu)
...@@ -221,21 +347,102 @@ NNVM_REGISTER_OP(relu) ...@@ -221,21 +347,102 @@ NNVM_REGISTER_OP(relu)
NNVM_REGISTER_OP(leaky_relu) NNVM_REGISTER_OP(leaky_relu)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward); .set_attr<FScaleAxisBackward>("FScaleAxisBackward", ReluScaleAxisBackward);
FoldScaleKind BroadcastAddSubScaleAxisBackward( NNVM_REGISTER_OP(relu)
.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
NNVM_REGISTER_OP(leaky_relu)
.set_attr<FScaleAxisForward>("FScaleAxisForward", ReluScaleAxisForward);
// property registration.
bool Pool2DBackward(
const NodeAttrs& attrs,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
const FoldChainInfo& out_info,
std::vector<FoldChainInfo>* in_axis) {
using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if (out_info.axis == 1 && param.layout == top::kNCHW) {
(*in_axis)[0] = out_info;
}
return false;
}
bool Pool2DForward(
const NodeAttrs& attrs, const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape, const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape, const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) { std::vector<FoldChainInfo>* in_info,
FoldChainInfo* out_info) {
using top::Pool2DParam;
const Pool2DParam& param = nnvm::get<Pool2DParam>(attrs.parsed);
if ((*in_info)[0].axis == 1 && param.layout == top::kNCHW) {
*out_info = (*in_info)[0];
}
return false;
}
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Pool2DBackward);
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
NNVM_REGISTER_OP(avg_pool2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Pool2DForward);
bool BroadcastAddSubScaleAxisBackward(
const NodeAttrs& attrs,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
const FoldChainInfo& out_info,
std::vector<FoldChainInfo>* in_axis) {
if (out_info.kind != kPending) return false;
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[i]); std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
if (m.second != -1 && in_shape[1 - i] == out_shape[0]) { if (m.second != -1 &&
in_axis->emplace_back(i, axis); in_shape[i] == out_shape[0] &&
in_axis->emplace_back(1 - i, m.second); m.first == out_info.axis) {
return kPassTroughFirst; (*in_axis)[i].kind = kPending;
(*in_axis)[i].axis = out_info.axis;
(*in_axis)[i].source = out_info.source;
(*in_axis)[1 - i].kind = kMulConsumer;
(*in_axis)[1 - i].axis = m.second;
(*in_axis)[1 - i].source = out_info.source;
return false;
} }
} }
return kNone; return false;
}
bool BroadcastAddSubScaleAxisForward(
const NodeAttrs& attrs,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<FoldChainInfo>* in_info,
FoldChainInfo* out_info) {
for (int i = 0; i < 2; ++i) {
if ((*in_info)[i].kind == kPending) {
std::pair<int, int> m = MatchBroadcast1DAxis(out_shape[0], in_shape[1 - i]);
if (m.second != -1 &&
in_shape[i] == out_shape[0] &&
m.first == (*in_info)[i].axis) {
out_info->kind = kPending;
out_info->axis = m.first;
out_info->source = (*in_info)[i].source;
(*in_info)[1 - i].kind = kDivConsumer;
(*in_info)[1 - i].axis = m.second;
(*in_info)[1 - i].source = (*in_info)[i].source;
return false;
}
}
}
return false;
} }
NNVM_REGISTER_OP(broadcast_add) NNVM_REGISTER_OP(broadcast_add)
...@@ -244,28 +451,62 @@ NNVM_REGISTER_OP(broadcast_add) ...@@ -244,28 +451,62 @@ NNVM_REGISTER_OP(broadcast_add)
NNVM_REGISTER_OP(broadcast_sub) NNVM_REGISTER_OP(broadcast_sub)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward); .set_attr<FScaleAxisBackward>("FScaleAxisBackward", BroadcastAddSubScaleAxisBackward);
FoldScaleKind Conv2DScaleAxisBackward( NNVM_REGISTER_OP(broadcast_add)
.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
NNVM_REGISTER_OP(broadcast_sub)
.set_attr<FScaleAxisForward>("FScaleAxisForward", BroadcastAddSubScaleAxisForward);
bool Conv2DScaleAxisBackward(
const NodeAttrs& attrs, const NodeAttrs& attrs,
int axis,
const std::vector<TShape>& in_shape, const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape, const std::vector<TShape>& out_shape,
std::vector<std::pair<uint32_t, int> >* in_axis) { const FoldChainInfo& out_info,
std::vector<FoldChainInfo>* in_axis) {
using top::Conv2DParam; using top::Conv2DParam;
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed); const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if (out_info.kind != kPending) return false;
// only optimize for nchw for now // only optimize for nchw for now
if (param.layout == top::kNCHW) { if (param.layout == top::kNCHW && out_info.axis == 1) {
in_axis->emplace_back(1, 0); (*in_axis)[1].kind = kMulConsumer;
(*in_axis)[1].axis = 0;
(*in_axis)[1].source = out_info.source;
if (param.use_bias) { if (param.use_bias) {
in_axis->emplace_back(2, 0); (*in_axis)[2].kind = kMulConsumer;
(*in_axis)[2].axis = 0;
(*in_axis)[2].source = out_info.source;
}
return true;
} else {
return false;
} }
return kMulConsumer; }
bool Conv2DScaleAxisForward(
const NodeAttrs& attrs,
const std::vector<TShape>& in_shape,
const std::vector<TShape>& out_shape,
std::vector<FoldChainInfo>* in_info,
FoldChainInfo* out_info) {
using top::Conv2DParam;
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if ((*in_info)[0].kind != kPending) return false;
// only optimize for nchw for now
if (param.layout == top::kNCHW && (*in_info)[0].axis == 1) {
(*in_info)[1].kind = kMulConsumer;
(*in_info)[1].axis = 1;
(*in_info)[1].source = (*in_info)[0].source;
return true;
} else { } else {
return kNone; return false;
} }
} }
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
.set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward); .set_attr<FScaleAxisBackward>("FScaleAxisBackward", Conv2DScaleAxisBackward);
NNVM_REGISTER_OP(conv2d)
.set_attr<FScaleAxisForward>("FScaleAxisForward", Conv2DScaleAxisForward);
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
...@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, ...@@ -196,7 +196,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
if (taken[kv.first] == false && if (taken[kv.first] == false &&
sid_out == GraphAllocator::kBadStorageID && sid_out == GraphAllocator::kBadStorageID &&
sid_in >= 0 && sid_in >= 0 &&
(storage_ref_count[sid_in] == 1 && !ignore_all_inputs || identity[ipair]) && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) &&
entry_ref_count[eid_out] > 0 && entry_ref_count[eid_out] > 0 &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in]) { dtype_vec[eid_out] == dtype_vec[eid_in]) {
......
"""Unittest cases for fold_axis""" """Unittest cases for fold_axis"""
import nnvm import nnvm
import nnvm.testing.resnet
import numpy as np
from nnvm import symbol as sym from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr from nnvm.compiler import graph_util, graph_attr
def test_fold_axis_conv(): def test_fold_axis_conv():
def before(x, conv_weight, conv_bias, scale, channels): def before(x, conv_weight, conv_bias, in_scale, out_scale, channels):
x = x * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
y = sym.conv2d(x, conv_weight, conv_bias, y = sym.conv2d(x, conv_weight, conv_bias,
channels=channels, channels=channels,
kernel_size=(3, 3), kernel_size=(3, 3),
padding=(1, 1), padding=(1, 1),
name="conv") name="conv")
y = sym.relu(y) y = sym.relu(y)
y = y * sym.expand_dims(scale, axis=1, num_newaxis=2) y = y * sym.expand_dims(out_scale, axis=1, num_newaxis=2)
return y return y
def expected(x, conv_weight, conv_bias, scale, channels): def expected(x, conv_weight, conv_bias, in_scale, out_scale, channels):
conv_weight = conv_weight * sym.expand_dims(scale, axis=1, num_newaxis=3) conv_weight = conv_weight * sym.expand_dims(out_scale, axis=1, num_newaxis=3)
conv_bias = conv_bias * scale conv_weight = conv_weight * sym.expand_dims(in_scale, axis=1, num_newaxis=2)
conv_bias = conv_bias * out_scale
y = sym.conv2d(x, y = sym.conv2d(x,
conv_weight, conv_weight,
conv_bias, conv_bias,
...@@ -32,10 +36,11 @@ def test_fold_axis_conv(): ...@@ -32,10 +36,11 @@ def test_fold_axis_conv():
x = sym.Variable("x") + 1 x = sym.Variable("x") + 1
weight = sym.Variable("weight") weight = sym.Variable("weight")
bias = sym.Variable("bias") bias = sym.Variable("bias")
scale = sym.Variable("scale") in_scale = sym.Variable("in_scale")
y1 = before(x, weight, bias, scale, channels) out_scale = sym.Variable("out_scale")
y2 = expected(x, weight, bias, scale, channels) y1 = before(x, weight, bias, in_scale, out_scale, channels)
ishape = {"x": shape, "scale": (channels,)} y2 = expected(x, weight, bias, in_scale, out_scale, channels)
ishape = {"x": shape, "out_scale": (channels,), "in_scale": (shape[1],)}
g1 = nnvm.graph.create(y1) g1 = nnvm.graph.create(y1)
g2 = nnvm.graph.create(y2) g2 = nnvm.graph.create(y2)
graph_attr.set_shape_inputs(g1, ishape) graph_attr.set_shape_inputs(g1, ishape)
...@@ -45,5 +50,61 @@ def test_fold_axis_conv(): ...@@ -45,5 +50,61 @@ def test_fold_axis_conv():
check((2, 4, 10, 10), 2) check((2, 4, 10, 10), 2)
def test_fold_fail():
    """Check that FoldScaleAxis leaves a graph it cannot fold unchanged.

    A conv2d output is multiplied by a scale expanded with a single new
    axis; after InferShape + FoldScaleAxis the resulting graph must be
    equal to the input graph.
    NOTE(review): presumably folding is rejected because the expanded
    scale does not line up with the conv channel axis -- confirm.
    """
    def before(data, scale_var, num_channels):
        # conv2d followed by a broadcast multiply with the expanded scale.
        conv_out = sym.conv2d(data,
                              channels=num_channels,
                              kernel_size=(3, 3),
                              padding=(1, 1),
                              name="conv")
        expanded_scale = sym.expand_dims(scale_var, axis=1, num_newaxis=1)
        return conv_out * expanded_scale

    def check(data_shape, num_channels):
        # Build the graph, run the pass, and compare against the original.
        x = sym.Variable("x")
        bias = sym.Variable("bias")
        scale = sym.Variable("scale")
        net = before(x, scale, num_channels)
        input_shapes = {
            "x": data_shape,
            "scale": (num_channels,),
            "bias": (num_channels,),
        }
        original = nnvm.graph.create(net)
        graph_attr.set_shape_inputs(original, input_shapes)
        shaped = original.apply("InferShape")
        folded = shaped.apply("FoldScaleAxis")
        # The pass must not have changed anything.
        graph_util.check_graph_equal(original, folded)

    check((2, 10, 10, 10), 10)
def test_fold_resnet():
    """Compare ResNet outputs with and without graph optimization.

    The workload is optimized, pruned and executed once at opt_level 0 and
    once at opt_level 3 on the same random input; the first outputs of
    both runs must agree.
    """
    batch_size = 1
    image_shape = (3, 224, 224)
    data_shape = (batch_size,) + image_shape
    net, params = nnvm.testing.resnet.get_workload(
        batch_size=batch_size, image_shape=image_shape)
    ishape = {"data": data_shape}
    graph = nnvm.graph.create(net)
    data = np.random.uniform(size=data_shape).astype("float32")
    # Initial pass to infer the shape of every graph input.
    shape, _ = graph_util.infer_shape(graph, **ishape)
    ishape.update(zip(graph.index.input_names, shape))

    def run_prune(graph, params, opt_level):
        """Optimize at `opt_level`, pre-compute/prune, then run the graph."""
        # Use the requested level; the previous hard-coded opt_level=0 made
        # both runs identical, so the comparison below checked nothing.
        with nnvm.compiler.build_config(opt_level=opt_level):
            graph = nnvm.compiler.optimize(graph, ishape)
        graph, params = nnvm.compiler.build_module.precompute_prune(graph, params)
        params["data"] = data
        return nnvm.compiler.build_module._run_graph(graph, params)

    x = run_prune(graph, params, 0)
    y = run_prune(graph, params, 3)
    # NOTE(review): the optimized graph presumably reorders float32
    # arithmetic, so bit-exact equality is not expected; the tolerance
    # is a judgement call -- tighten or loosen after running on target.
    np.testing.assert_allclose(y[0].asnumpy(), x[0].asnumpy(), rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_fold_resnet()
test_fold_axis_conv() test_fold_axis_conv()
test_fold_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment