Commit 36ea5392 by nhynes Committed by Tianqi Chen

Only warn when unable to find a graph input (#1052)

parent cc7a8fcf
...@@ -88,7 +88,7 @@ class GraphRuntime : public ModuleNode { ...@@ -88,7 +88,7 @@ class GraphRuntime : public ModuleNode {
return static_cast<int>(i); return static_cast<int>(i);
} }
} }
LOG(FATAL) << "cannot find " << name << " among input"; LOG(WARNING) << "Warning: cannot find \"" << name << "\" among input";
return -1; return -1;
} }
/*! /*!
...@@ -459,7 +459,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { ...@@ -459,7 +459,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
CHECK(size == names.size()) CHECK(size == names.size())
<< "Invalid parameters file format"; << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
uint32_t in_idx = GetInputIndex(names[i]); int in_idx = GetInputIndex(names[i]);
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
CHECK_LT(eid, data_entry_.size()); CHECK_LT(eid, data_entry_.size());
LoadDLTensor(strm, &data_entry_[eid]); LoadDLTensor(strm, &data_entry_[eid]);
...@@ -585,7 +586,8 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -585,7 +586,8 @@ PackedFunc GraphRuntime::GetFunction(
if (name == "set_input") { if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) { if (args[0].type_code() == kStr) {
this->SetInput(this->GetInputIndex(args[0]), args[1]); int in_idx = this->GetInputIndex(args[0]);
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else { } else {
this->SetInput(args[0], args[1]); this->SetInput(args[0], args[1]);
} }
...@@ -597,7 +599,9 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -597,7 +599,9 @@ PackedFunc GraphRuntime::GetFunction(
} else if (name == "get_input") { } else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) { if (args[0].type_code() == kStr) {
this->GetInput(this->GetInputIndex(args[0]), args[1]); int in_idx = this->GetInputIndex(args[0]);
CHECK_GE(in_idx, 0);
this->GetInput(in_idx, args[1]);
} else { } else {
this->GetInput(args[0], args[1]); this->GetInput(args[0], args[1]);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment