Commit e4820d34 by Eric Junyuan Xie Committed by Tianqi Chen

Fix json parsing behavior (#83)

parent 98fd6bd0
...@@ -63,7 +63,9 @@ enum ParamInitOption { ...@@ -63,7 +63,9 @@ enum ParamInitOption {
/*! \brief allow unknown parameters */ /*! \brief allow unknown parameters */
kAllowUnknown, kAllowUnknown,
/*! \brief need to match exact parameters */ /*! \brief need to match exact parameters */
kAllMatch kAllMatch,
/*! \brief allow unmatched hidden field with format __*__ */
kAllowHidden
}; };
} // namespace parameter } // namespace parameter
/*! /*!
...@@ -122,11 +124,11 @@ struct Parameter { ...@@ -122,11 +124,11 @@ struct Parameter {
*/ */
template<typename Container> template<typename Container>
inline void Init(const Container &kwargs, inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowUnknown) { parameter::ParamInitOption option = parameter::kAllowHidden) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), kwargs.begin(), kwargs.end(),
NULL, NULL,
option == parameter::kAllowUnknown); option);
} }
/*! /*!
* \brief initialize the parameter by keyword arguments. * \brief initialize the parameter by keyword arguments.
...@@ -143,7 +145,7 @@ struct Parameter { ...@@ -143,7 +145,7 @@ struct Parameter {
std::vector<std::pair<std::string, std::string> > unknown; std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), kwargs.begin(), kwargs.end(),
&unknown, true); &unknown, parameter::kAllowUnknown);
return unknown; return unknown;
} }
/*! /*!
...@@ -369,7 +371,7 @@ class ParamManager { ...@@ -369,7 +371,7 @@ class ParamManager {
RandomAccessIterator begin, RandomAccessIterator begin,
RandomAccessIterator end, RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args, std::vector<std::pair<std::string, std::string> > *unknown_args,
bool allow_unknown) const { parameter::ParamInitOption option) const {
std::set<FieldAccessEntry*> selected_args; std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) { for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first); FieldAccessEntry *e = Find(it->first);
...@@ -381,7 +383,13 @@ class ParamManager { ...@@ -381,7 +383,13 @@ class ParamManager {
if (unknown_args != NULL) { if (unknown_args != NULL) {
unknown_args->push_back(*it); unknown_args->push_back(*it);
} else { } else {
if (!allow_unknown) { if (option != parameter::kAllowUnknown) {
if (option == parameter::kAllowHidden &&
it->first.length() > 4 &&
it->first.find("__") == 0 &&
it->first.rfind("__") == it->first.length()-2) {
continue;
}
std::ostringstream os; std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n"; os << "----------------\n";
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace nnvm { namespace nnvm {
namespace symbol_constants { namespace symbol_constants {
const char *kNamespaceSeparator = "_"; const char *kNamespaceSeparator = "$";
} // namespace symbol_constants } // namespace symbol_constants
// auxililary version attribute in variable. // auxililary version attribute in variable.
......
...@@ -109,10 +109,6 @@ struct JSONNode { ...@@ -109,10 +109,6 @@ struct JSONNode {
if (op_type_str != "null") { if (op_type_str != "null") {
try { try {
node->attrs.op = Op::Get(op_type_str); node->attrs.op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) { } catch (const dmlc::Error &err) {
std::ostringstream os; std::ostringstream os;
os << "Failed loading Op " << node->attrs.name os << "Failed loading Op " << node->attrs.name
...@@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) { ...@@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) {
<< "Load JSON require json to be presented."; << "Load JSON require json to be presented.";
const std::string &json_str = const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json")); nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
}
std::istringstream is(json_str); std::istringstream is(json_str);
dmlc::JSONReader reader(&is); dmlc::JSONReader reader(&is);
JSONGraph jgraph; JSONGraph jgraph;
...@@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) { ...@@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) {
for (uint32_t nid : n.control_deps) { for (uint32_t nid : n.control_deps) {
n.node->control_deps.push_back(jgraph.nodes[nid].node); n.node->control_deps.push_back(jgraph.nodes[nid].node);
} }
// rebuild attribute parser
if (!no_parse && n.node->op() != nullptr &&
n.node->op()->attr_parser != nullptr) {
n.node->op()->attr_parser(&(n.node->attrs));
}
} }
// consistent check // consistent check
for (uint32_t nid : jgraph.arg_nodes) { for (uint32_t nid : jgraph.arg_nodes) {
......
...@@ -12,7 +12,7 @@ def test_compose(): ...@@ -12,7 +12,7 @@ def test_compose():
assert y.list_attr()['gpu'] == '1' assert y.list_attr()['gpu'] == '1'
z = y.get_internals() z = y.get_internals()
assert z['add_output'].list_output_names() == ['add_output'] assert z['add_output'].list_output_names() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2' assert y.list_attr(recursive=True)['add$gpu'] == '2'
def test_default_input(): def test_default_input():
x = sym.Variable('x') x = sym.Variable('x')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment