Commit 464c8c26 by Junru Shao Committed by Tianqi Chen

Fix a bug in Symbol::Compose when using subgraphs as input (#1314)

parent baa04599
...@@ -315,19 +315,26 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -315,19 +315,26 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const Symbol *sym; const Symbol *sym;
if (idx < arg_vec.size()) { if (idx < arg_vec.size()) {
sym = arg_vec[idx]; sym = arg_vec[idx];
arg_vec.erase(arg_vec.begin() + idx);
} else { } else {
auto it = kwarg_map.find(arg_names[idx]); auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end()); CHECK(it != kwarg_map.end());
sym = it->second; sym = it->second;
kwarg_map.erase(it); kwarg_map.erase(it);
} }
if (n_req != kVarg) if (n_req != kVarg)
n_req--; n_req--;
arg_names.erase(arg_names.begin() + idx);
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym)); n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
} }
// Because idxes does not contain duplicates, the loop below functions well.
// Note that it is as slow as O(|idxes| * |args|),
// but given that |idxes| is small, it is just fine
sort(std::begin(idxes), std::end(idxes), std::greater<int>());
for (auto idx : idxes) {
if (idx < arg_vec.size()) {
arg_vec.erase(arg_vec.begin() + idx);
}
arg_names.erase(arg_names.begin() + idx);
}
} }
if (n_req != kVarg) { if (n_req != kVarg) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment