Commit e4820d34 by Eric Junyuan Xie Committed by Tianqi Chen

Fix json parsing behavior (#83)

parent 98fd6bd0
......@@ -63,7 +63,9 @@ enum ParamInitOption {
/*! \brief allow unknown parameters */
kAllowUnknown,
/*! \brief need to match exact parameters */
kAllMatch
kAllMatch,
/*! \brief allow unmatched hidden field with format __*__ */
kAllowHidden
};
} // namespace parameter
/*!
......@@ -122,11 +124,11 @@ struct Parameter {
*/
template<typename Container>
inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowUnknown) {
parameter::ParamInitOption option = parameter::kAllowHidden) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
NULL,
option == parameter::kAllowUnknown);
option);
}
/*!
* \brief initialize the parameter by keyword arguments.
......@@ -143,7 +145,7 @@ struct Parameter {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
&unknown, true);
&unknown, parameter::kAllowUnknown);
return unknown;
}
/*!
......@@ -369,7 +371,7 @@ class ParamManager {
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args,
bool allow_unknown) const {
parameter::ParamInitOption option) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
......@@ -381,7 +383,13 @@ class ParamManager {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
if (!allow_unknown) {
if (option != parameter::kAllowUnknown) {
if (option == parameter::kAllowHidden &&
it->first.length() > 4 &&
it->first.find("__") == 0 &&
it->first.rfind("__") == it->first.length()-2) {
continue;
}
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
......
......@@ -10,7 +10,7 @@
namespace nnvm {
namespace symbol_constants {
const char *kNamespaceSeparator = "_";
const char *kNamespaceSeparator = "$";
} // namespace symbol_constants
// auxililary version attribute in variable.
......
......@@ -109,10 +109,6 @@ struct JSONNode {
if (op_type_str != "null") {
try {
node->attrs.op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) {
std::ostringstream os;
os << "Failed loading Op " << node->attrs.name
......@@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) {
<< "Load JSON require json to be presented.";
const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
}
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
......@@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) {
for (uint32_t nid : n.control_deps) {
n.node->control_deps.push_back(jgraph.nodes[nid].node);
}
// rebuild attribute parser
if (!no_parse && n.node->op() != nullptr &&
n.node->op()->attr_parser != nullptr) {
n.node->op()->attr_parser(&(n.node->attrs));
}
}
// consistent check
for (uint32_t nid : jgraph.arg_nodes) {
......
......@@ -12,7 +12,7 @@ def test_compose():
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
assert z['add_output'].list_output_names() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2'
assert y.list_attr(recursive=True)['add$gpu'] == '2'
def test_default_input():
x = sym.Variable('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment