Commit 4363d0fb by Tianqi Chen

[PASS] Make placedevice compatible with backward op (#88)

parent 996ff839
...@@ -25,6 +25,8 @@ Graph PlaceDevice(Graph src) { ...@@ -25,6 +25,8 @@ Graph PlaceDevice(Graph src) {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op")); const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map"); auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
DeviceVector device; DeviceVector device;
// copy on write semanatics // copy on write semanatics
if (src.attrs.count("device") != 0) { if (src.attrs.count("device") != 0) {
...@@ -45,9 +47,16 @@ Graph PlaceDevice(Graph src) { ...@@ -45,9 +47,16 @@ Graph PlaceDevice(Graph src) {
<< "The device assignment not found for group " << device_group; << "The device assignment not found for group " << device_group;
device[nid] = dit->second; device[nid] = dit->second;
} else { } else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (!inode.source->is_variable() &&
if (device[e.node_id] != -1) { is_backward.get(inode.source->op(), false)) {
device[nid] = device[e.node_id]; break; if (device[inode.control_deps[0]] != -1) {
device[nid] = device[inode.control_deps[0]];
}
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment