Commit f0d8b594 by Eric Junyuan Xie Committed by Tianqi Chen

allow variable composition (#133)

parent 8f83be7a
......@@ -87,8 +87,11 @@ inline std::vector<std::string> GetKeys(
// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs[0].node->inputs.size() == 0 &&
outputs[0].node->control_deps.size() == 0;
Node* node = outputs[0].node.get();
for (const NodeEntry& e : outputs) {
if (node != e.node.get()) return false;
}
return node->inputs.size() == 0 && node->control_deps.size() == 0;
}
// public functions
......@@ -261,7 +264,14 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
for (size_t i = 0; i < outputs.size(); ++i) {
if (outputs[i].node->is_variable()) {
CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
const auto it = kwargs.find(outputs[i].node->attrs.name);
if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
}
}
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
......@@ -271,13 +281,13 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
outputs[0].node->attrs.name = name;
// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
// assign new name
if (!name.empty()) n->attrs.name = name;
if (n_req != kVarg) {
n->inputs.resize(n_req);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment