Commit 4363d0fb by Tianqi Chen

[PASS] Make placedevice compatible with backward op (#88)

parent 996ff839
......@@ -25,6 +25,8 @@ Graph PlaceDevice(Graph src) {
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
const IndexedGraph& idx = src.indexed_graph();
static auto& is_backward =
Op::GetAttr<TIsBackward>("TIsBackward");
DeviceVector device;
// copy on write semanatics
if (src.attrs.count("device") != 0) {
......@@ -45,9 +47,16 @@ Graph PlaceDevice(Graph src) {
<< "The device assignment not found for group " << device_group;
device[nid] = dit->second;
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
if (!inode.source->is_variable() &&
is_backward.get(inode.source->op(), false)) {
if (device[inode.control_deps[0]] != -1) {
device[nid] = device[inode.control_deps[0]];
}
} else {
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
if (device[e.node_id] != -1) {
device[nid] = device[e.node_id]; break;
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment