Commit a6f6a0e0 by Tianqi Chen

Place device now compatible and tested (#33)

parent 9a4e1339
...@@ -259,7 +259,8 @@ inline const std::type_info& any::type() const { ...@@ -259,7 +259,8 @@ inline const std::type_info& any::type() const {
template<typename T> template<typename T>
inline void any::check_type() const { inline void any::check_type() const {
CHECK(type_ != nullptr) CHECK(type_ != nullptr)
<< "The any container is empty"; << "The any container is empty"
<< " requested=" << typeid(T).name();
CHECK(type_->ptype_info == &typeid(T)) CHECK(type_->ptype_info == &typeid(T))
<< "The stored type mismatch" << "The stored type mismatch"
<< " stored=" << type_->ptype_info->name() << " stored=" << type_->ptype_info->name()
......
...@@ -57,6 +57,14 @@ class FieldEntry; ...@@ -57,6 +57,14 @@ class FieldEntry;
// forward declare ParamManagerSingleton // forward declare ParamManagerSingleton
template<typename PType> template<typename PType>
struct ParamManagerSingleton; struct ParamManagerSingleton;
/*! \brief option in parameter initialization */
enum ParamInitOption {
/*! \brief allow unknown parameters */
kAllowUnknown,
/*! \brief need to match exact parameters */
kAllMatch
};
} // namespace parameter } // namespace parameter
/*! /*!
* \brief Information about a parameter field in string representations. * \brief Information about a parameter field in string representations.
...@@ -108,13 +116,17 @@ struct Parameter { ...@@ -108,13 +116,17 @@ struct Parameter {
* and throw error if something wrong happens. * and throw error if something wrong happens.
* *
* \param kwargs map of keyword arguments, or vector of pairs * \param kwargs map of keyword arguments, or vector of pairs
* \parma option The option on initialization.
* \tparam Container container type * \tparam Container container type
* \throw ParamError when something go wrong. * \throw ParamError when something go wrong.
*/ */
template<typename Container> template<typename Container>
inline void Init(const Container &kwargs) { inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowUnknown) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), NULL); kwargs.begin(), kwargs.end(),
NULL,
option == parameter::kAllowUnknown);
} }
/*! /*!
* \brief initialize the parameter by keyword arguments. * \brief initialize the parameter by keyword arguments.
...@@ -130,7 +142,8 @@ struct Parameter { ...@@ -130,7 +142,8 @@ struct Parameter {
InitAllowUnknown(const Container &kwargs) { InitAllowUnknown(const Container &kwargs) {
std::vector<std::pair<std::string, std::string> > unknown; std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this), PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), &unknown); kwargs.begin(), kwargs.end(),
&unknown, true);
return unknown; return unknown;
} }
/*! /*!
...@@ -355,7 +368,8 @@ class ParamManager { ...@@ -355,7 +368,8 @@ class ParamManager {
inline void RunInit(void *head, inline void RunInit(void *head,
RandomAccessIterator begin, RandomAccessIterator begin,
RandomAccessIterator end, RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args) const { std::vector<std::pair<std::string, std::string> > *unknown_args,
bool allow_unknown) const {
std::set<FieldAccessEntry*> selected_args; std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) { for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first); FieldAccessEntry *e = Find(it->first);
...@@ -367,11 +381,13 @@ class ParamManager { ...@@ -367,11 +381,13 @@ class ParamManager {
if (unknown_args != NULL) { if (unknown_args != NULL) {
unknown_args->push_back(*it); unknown_args->push_back(*it);
} else { } else {
std::ostringstream os; if (!allow_unknown) {
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; std::ostringstream os;
os << "----------------\n"; os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
PrintDocString(os); os << "----------------\n";
throw dmlc::ParamError(os.str()); PrintDocString(os);
throw dmlc::ParamError(os.str());
}
} }
} }
} }
......
...@@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) { ...@@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op")); const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map"); auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
DeviceVector device; DeviceVector device;
// copy on write semanatics // copy on write semanatics
if (src.attrs.count("device") != 0) { if (src.attrs.count("device") != 0) {
...@@ -79,10 +78,10 @@ Graph PlaceDevice(Graph src) { ...@@ -79,10 +78,10 @@ Graph PlaceDevice(Graph src) {
src.attrs["device"] = std::make_shared<any>(std::move(device)); src.attrs["device"] = std::make_shared<any>(std::move(device));
return src; return src;
} }
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map; std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr); std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
std::unordered_map<const Node*, int> new_device_map; std::unordered_map<const Node*, int> new_device_map;
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
// insert copy node // insert copy node
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
...@@ -90,6 +89,16 @@ Graph PlaceDevice(Graph src) { ...@@ -90,6 +89,16 @@ Graph PlaceDevice(Graph src) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
// check if mutation is needed // check if mutation is needed
bool need_mutate = false; bool need_mutate = false;
if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) {
for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) {
auto e = inode.inputs[index];
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
LOG(FATAL) << " mutable state cannot go across device"
<< " op=" << inode.source->op()->name
<< " input_state_index=" << index;
}
}
}
for (const IndexedGraph::NodeEntry& e : inode.inputs) { for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
need_mutate = true; break; need_mutate = true; break;
...@@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) { ...@@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) {
} }
} }
} }
if (inode.source->is_variable()) {
CHECK(!need_mutate) << "consistency check";
}
if (need_mutate) { if (need_mutate) {
NodePtr new_node = Node::Create(); NodePtr new_node = Node::Create();
new_node->attrs = inode.source->attrs; new_node->attrs = inode.source->attrs;
...@@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) { ...@@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) {
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
copy_node->attrs.op = copy_op; copy_node->attrs.op = copy_op;
copy_node->attrs.name = os.str(); copy_node->attrs.name = os.str();
copy_node->inputs.push_back(inode.source->inputs[i]); if (new_node_map[e.node_id] != nullptr) {
copy_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0});
} else {
copy_node->inputs.push_back(inode.source->inputs[i]);
}
if (copy_node->attrs.op->attr_parser != nullptr) {
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
}
copy_map[copy_key] = copy_node; copy_map[copy_key] = copy_node;
new_device_map[copy_node.get()] = dev_id; new_device_map[copy_node.get()] = dev_id;
new_node->inputs.emplace_back( new_node->inputs.emplace_back(
...@@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) { ...@@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) {
if (new_node_map[e.node_id] != nullptr) { if (new_node_map[e.node_id] != nullptr) {
new_node->inputs.emplace_back( new_node->inputs.emplace_back(
NodeEntry{new_node_map[e.node_id], e.index, 0}); NodeEntry{new_node_map[e.node_id], e.index, 0});
} else { } else {
new_node->inputs.push_back(inode.source->inputs[i]); new_node->inputs.push_back(inode.source->inputs[i]);
} }
} }
...@@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) { ...@@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) {
new_device_map[inode.source] = dev_id; new_device_map[inode.source] = dev_id;
} }
} }
// make the new graph // make the new graph
Graph ret; Graph ret;
for (const NodeEntry& e : src.outputs) { for (const NodeEntry& e : src.outputs) {
...@@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) { ...@@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) {
} }
DeviceVector new_device_vec(ret.indexed_graph().num_nodes()); DeviceVector new_device_vec(ret.indexed_graph().num_nodes());
for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) { for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) {
if (new_device_map.count(ret.indexed_graph()[nid].source) == 0) { auto source = ret.indexed_graph()[nid].source;
LOG(INFO) << "canot find " << ret.indexed_graph()[nid].source->attrs.name; if (new_device_map.count(source) == 0) {
LOG(FATAL) << "canot find " << source;
} }
new_device_vec[nid] = new_device_map.at(ret.indexed_graph()[nid].source); new_device_vec[nid] = new_device_map.at(source);
} }
ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec)); ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
return ret; return ret;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment