Commit 36ea5392 by nhynes Committed by Tianqi Chen

Only warn when unable to find a graph input (#1052)

parent cc7a8fcf
......@@ -88,7 +88,7 @@ class GraphRuntime : public ModuleNode {
return static_cast<int>(i);
}
}
LOG(FATAL) << "cannot find " << name << " among input";
LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input";
return -1;
}
/*!
......@@ -459,7 +459,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
uint32_t in_idx = GetInputIndex(names[i]);
int in_idx = GetInputIndex(names[i]);
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
CHECK_LT(eid, data_entry_.size());
LoadDLTensor(strm, &data_entry_[eid]);
......@@ -585,7 +586,8 @@ PackedFunc GraphRuntime::GetFunction(
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->SetInput(this->GetInputIndex(args[0]), args[1]);
int in_idx = this->GetInputIndex(args[0]);
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
}
......@@ -597,7 +599,9 @@ PackedFunc GraphRuntime::GetFunction(
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) {
this->GetInput(this->GetInputIndex(args[0]), args[1]);
int in_idx = this->GetInputIndex(args[0]);
CHECK_GE(in_idx, 0);
this->GetInput(in_idx, args[1]);
} else {
this->GetInput(args[0], args[1]);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment