Commit a1404e23 by Eric Junyuan Xie Committed by Tianqi Chen

fix symbol output compose (#166)

parent 35a7eac5
...@@ -268,14 +268,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -268,14 +268,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames"); static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose"); static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
for (size_t i = 0; i < outputs.size(); ++i) {
if (outputs[i].node->is_variable()) {
CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
const auto it = kwargs.find(outputs[i].node->attrs.name);
if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
}
}
// parameter check. // parameter check.
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U) CHECK_EQ(args[i]->outputs.size(), 1U)
...@@ -407,6 +399,15 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -407,6 +399,15 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
dmlc::BeginPtr(arg_names) + arg_names.size()); dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, arg_names); KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
} }
// update outputs in case the composed variable is part of outputs.
for (size_t i = 0; i < outputs.size(); ++i) {
if (outputs[i].node->is_variable()) {
CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
const auto it = kwargs.find(outputs[i].node->attrs.name);
if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
}
}
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment