Commit 204c4442 by Tianqi Chen

[PASS] Add place device (#18)

parent cf02f5c9
export LDFLAGS = -pthread -lm export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops\ export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC -Iinclude -Idmlc-core/include -fPIC
# specify tensor path # specify tensor path
......
...@@ -20,6 +20,13 @@ interface defintion and how operators are executed. ...@@ -20,6 +20,13 @@ interface defintion and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library NNVM is inspired by LLVM, aiming to be an intermediate representation library
for neural nets and computation graphs generation and optimizations. for neural nets and computation graphs generation and optimizations.
## Why build deep learning system by parts
- Essential parts can be assembled in a minimal way for embedded systems.
- Hackers can hack the parts they need and compose them with other well-defined parts.
- Decentralized modules enable creators of new extensions to own their project
without creating a monolithic version.
## Deep learning system by parts ## Deep learning system by parts
This is one way to divide the deep learning system into common parts. This is one way to divide the deep learning system into common parts.
......
...@@ -71,14 +71,8 @@ class IndexedGraph { ...@@ -71,14 +71,8 @@ class IndexedGraph {
uint32_t node_id; uint32_t node_id;
/*! \brief index of output from the source. */ /*! \brief index of output from the source. */
uint32_t index; uint32_t index;
/*! /*! \brief version of the node */
* \brief compare equality uint32_t version;
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const NodeEntry& other) const {
return node_id == other.node_id && index == other.index;
}
}; };
/*! \brief Node data structure in IndexedGraph */ /*! \brief Node data structure in IndexedGraph */
struct Node { struct Node {
......
...@@ -45,7 +45,7 @@ using ShapeVector = std::vector<TShape>; ...@@ -45,7 +45,7 @@ using ShapeVector = std::vector<TShape>;
* *
* \code * \code
* Graph g = ApplyPass(src_graph, {"InferType"}); * Graph g = ApplyPass(src_graph, {"InferType"});
* const DTypeVector& types = g.GetAttr<ShapeVector>("dtype"); * const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id * // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode * \endcode
...@@ -54,6 +54,28 @@ using ShapeVector = std::vector<TShape>; ...@@ -54,6 +54,28 @@ using ShapeVector = std::vector<TShape>;
*/ */
using DTypeVector = std::vector<int>; using DTypeVector = std::vector<int>;
/*!
 * \brief The result holder of device of each operator in the graph.
 * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
 *
 * \code
 * Graph g = ApplyPass(src_graph, {"PlaceDevice"});
 * const DeviceVector& device = g.GetAttr<DeviceVector>("device");
 * // get device by node_id
 * int device_type = device[g.indexed_graph().node_id(my_node)];
 * \endcode
 */
using DeviceVector = std::vector<int>;
/*!
 * \brief The assignment map from a device-group label to a device id,
 *  used as placement hint by Pass "PlaceDevice".
 *
 * \note Stored under graph.attrs["device_assign_map"], needed by Pass "PlaceDevice"
 *  -1 means unknown device
 */
using DeviceAssignMap = std::unordered_map<std::string, int>;
} // namespace nnvm } // namespace nnvm
#endif // NNVM_GRAPH_ATTR_TYPES_H_ #endif // NNVM_GRAPH_ATTR_TYPES_H_
...@@ -91,6 +91,24 @@ inline Graph InferType(Graph graph, ...@@ -91,6 +91,24 @@ inline Graph InferType(Graph graph,
return ApplyPass(std::move(graph), {"InferType"}); return ApplyPass(std::move(graph), {"InferType"});
} }
/*!
 * \brief Place the devices of each operator in the graph.
 *
 *  Convenience wrapper: stores the three configuration attributes
 *  consumed by the "PlaceDevice" pass on the graph, then applies the pass.
 *
 * \param graph source graph
 * \param device_group_attr_key The attribute name for hinting the device group.
 * \param device_assign_map The assignment map from device group to device id.
 * \param device_copy_op The name of copy op to be inserted when cross device copy happened.
 * \return A graph with new attribute "device", containing device information of each node.
 */
inline Graph PlaceDevice(Graph graph,
                         std::string device_group_attr_key,
                         DeviceAssignMap device_assign_map,
                         std::string device_copy_op) {
  // Pass configuration is communicated through graph.attrs; the pass body
  // reads these three keys.
  graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
  graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
  graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
  return ApplyPass(std::move(graph), {"PlaceDevice"});
}
} // namespace pass } // namespace pass
} // namespace nnvm } // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_ #endif // NNVM_PASS_FUNCTIONS_H_
...@@ -40,7 +40,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { ...@@ -40,7 +40,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
for (const auto& e : n->inputs) { for (const auto& e : n->inputs) {
auto it = node2index_.find(e.node.get()); auto it = node2index_.find(e.node.get());
CHECK(it != node2index_.end() && it->first == e.node.get()); CHECK(it != node2index_.end() && it->first == e.node.get());
input_entries_.emplace_back(NodeEntry{it->second, e.index}); input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version});
} }
inputs_rptr.push_back(input_entries_.size()); inputs_rptr.push_back(input_entries_.size());
// control deps // control deps
......
...@@ -94,6 +94,11 @@ NNVM_REGISTER_OP(exp) ...@@ -94,6 +94,11 @@ NNVM_REGISTER_OP(exp)
.attr("inplace_pair", std::make_pair(0, 0)) .attr("inplace_pair", std::make_pair(0, 0))
.attr<FInferShape>("FInferShape", SameShape); .attr<FInferShape>("FInferShape", SameShape);
// Single-input copy operator; it can be named as the device_copy_op that a
// pass (e.g. "PlaceDevice") inserts on cross-device edges — see
// test_place_device, which passes "cross_device_copy" for that attribute.
NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.")
.set_num_inputs(1)
// the copied output has the same shape as its input
.attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
.describe("take conv of input") .describe("take conv of input")
......
...@@ -35,10 +35,14 @@ Graph InferAttr(Graph &&ret, ...@@ -35,10 +35,14 @@ Graph InferAttr(Graph &&ret,
for (size_t i = 0; i < shape_args.size(); ++i) { for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i]; rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
} }
// erase the provided arguments
ret.attrs.erase(arg_name);
} }
std::string shape_attr_key; std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) { if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name); shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
// erase the provided arguments
ret.attrs.erase(attr_key_name);
} }
// temp space for shape inference. // temp space for shape inference.
......
/*!
* Copyright (c) 2016 by Contributors
* \file place_device.cc
 * \brief Infer the device of each operator given known information.
 *  Insert a copy node automatically when data crosses a device boundary.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
namespace nnvm {
namespace pass {
// Simple logic to place devices according to the device_group hint:
// propagate known assignments through the graph and insert a copy node
// wherever an edge crosses a device boundary.
/*!
 * \brief Pass body of "PlaceDevice".
 * \param src source graph; must carry graph attributes
 *  "device_group_attr_key", "device_assign_map" and "device_copy_op".
 * \return a graph with new attribute "device" (DeviceVector indexed by
 *  node id); the graph may contain newly inserted copy nodes.
 */
Graph PlaceDevice(Graph src) {
  CHECK_NE(src.attrs.count("device_group_attr_key"), 0)
      << "Need graph attribute \"device_group_attr_key\" in PlaceDevice";
  CHECK_NE(src.attrs.count("device_assign_map"), 0)
      << "Need graph attribute \"device_assign_map\" in PlaceDevice";
  CHECK_NE(src.attrs.count("device_copy_op"), 0)
      << "Need graph attribute \"device_copy_op\" in PlaceDevice";
  std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
  const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
  auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
  const IndexedGraph& idx = src.indexed_graph();
  // device[nid] is the device id assigned to node nid; -1 means unknown.
  DeviceVector device(idx.num_nodes(), -1);
  // forward pass: a node with an explicit device_group attribute gets the
  // mapped device id; otherwise it inherits from its first placed input.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    auto it = inode.source->attrs.dict.find(device_group_attr_key);
    if (it != inode.source->attrs.dict.end()) {
      const std::string& device_group = it->second;
      auto dit = device_assign_map.find(device_group);
      CHECK_NE(dit, device_assign_map.end())
          << "The device assignment not found for group " << device_group;
      device[nid] = dit->second;
    } else {
      for (const IndexedGraph::NodeEntry& e : inode.inputs) {
        if (device[e.node_id] != -1) {
          device[nid] = device[e.node_id]; break;
        }
      }
    }
  }
  // backward pass: propagate placements to still-unplaced producers, so
  // an input defaults to the device of a placed consumer.
  for (uint32_t i = idx.num_nodes(); i != 0; --i) {
    uint32_t nid = i - 1;
    const auto& inode = idx[nid];
    if (device[nid] == -1) continue;
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      if (device[e.node_id] == -1) device[e.node_id] = device[nid];
    }
  }
  // Default remaining unknowns to device 0 and detect whether more than one
  // device is actually used: num_dev counts value transitions along the
  // vector, which stays 1 iff all entries are equal.
  int num_dev = 1, other_dev_id = -1;
  for (int& dev : device) {
    if (dev == -1) dev = 0;
    if (dev != other_dev_id) {
      if (other_dev_id != -1) ++num_dev;
      other_dev_id = dev;
    }
  }
  if (num_dev == 1) {
    // Single device: no copy node is needed; strip the configuration
    // attributes and record the placement directly.
    src.attrs.erase("device_group_attr_key");
    src.attrs.erase("device_assign_map");
    src.attrs.erase("device_copy_op");
    src.attrs["device"] = std::make_shared<any>(std::move(device));
    return src;
  }
  // (source node id, output index, destination device) -> shared copy node,
  // so one output copied to the same device is materialized only once.
  std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
  // new_node_map[nid] is the rewritten replacement of node nid
  // (nullptr when the original node can be reused unchanged).
  std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
  std::unordered_map<const Node*, int> new_device_map;
  // insert copy node
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    int dev_id = device[nid];
    const auto& inode = idx[nid];
    // check if mutation is needed: an input on a different device, or an
    // input/control dependency that was itself rewritten.
    bool need_mutate = false;
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
        need_mutate = true; break;
      }
    }
    if (!need_mutate) {
      for (const uint32_t cid : inode.control_deps) {
        if (new_node_map[cid] != nullptr) {
          need_mutate = true; break;
        }
      }
    }
    if (need_mutate) {
      NodePtr new_node = Node::Create();
      new_node->attrs = inode.source->attrs;
      new_node->inputs.reserve(inode.inputs.size());
      for (size_t i = 0; i < inode.inputs.size(); ++i) {
        const IndexedGraph::NodeEntry& e = inode.inputs[i];
        if (dev_id != device[e.node_id]) {
          // cross-device edge: route the input through a copy node,
          // reusing an already-created one for the same (entry, device).
          auto copy_key = std::make_tuple(e.node_id, e.index, dev_id);
          auto it = copy_map.find(copy_key);
          if (it != copy_map.end() && it->first == copy_key) {
            new_node->inputs.emplace_back(
                NodeEntry{it->second, 0, 0});
          } else {
            NodePtr copy_node = Node::Create();
            copy_node->op = copy_op;
            std::ostringstream os;
            os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
            copy_node->attrs.name = os.str();
            copy_node->inputs.push_back(inode.source->inputs[i]);
            copy_map[copy_key] = copy_node;
            // the copy node lives on the consumer's device
            new_device_map[copy_node.get()] = dev_id;
            new_node->inputs.emplace_back(
                NodeEntry{std::move(copy_node), 0, 0});
          }
        } else {
          // same device: point at the rewritten producer if one exists.
          if (new_node_map[e.node_id] != nullptr) {
            new_node->inputs.emplace_back(
                NodeEntry{new_node_map[e.node_id], e.index, 0});
          } else {
            new_node->inputs.push_back(inode.source->inputs[i]);
          }
        }
      }
      // redirect control dependencies to rewritten nodes where needed
      new_node->control_deps.reserve(inode.control_deps.size());
      for (size_t i = 0; i < inode.control_deps.size(); ++i) {
        uint32_t cid = inode.control_deps[i];
        if (new_node_map[cid] != nullptr) {
          new_node->control_deps.push_back(new_node_map[cid]);
        } else {
          new_node->control_deps.push_back(inode.source->control_deps[i]);
        }
      }
      new_device_map[new_node.get()] = dev_id;
      new_node_map[nid] = std::move(new_node);
    } else {
      new_device_map[inode.source] = dev_id;
    }
  }
  // make the new graph, substituting rewritten output nodes
  Graph ret;
  for (const NodeEntry& e : src.outputs) {
    if (new_node_map[idx.node_id(e.node.get())] != nullptr) {
      ret.outputs.emplace_back(
          NodeEntry{new_node_map[idx.node_id(e.node.get())], e.index, e.version});
    } else {
      ret.outputs.emplace_back(e);
    }
  }
  // re-index the mutated graph and emit the per-node device vector
  DeviceVector new_device_vec(ret.indexed_graph().num_nodes());
  for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) {
    if (new_device_map.count(ret.indexed_graph()[nid].source) == 0) {
      LOG(INFO) << "cannot find " << ret.indexed_graph()[nid].source->attrs.name;
    }
    new_device_vec[nid] = new_device_map.at(ret.indexed_graph()[nid].source);
  }
  ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
  return ret;
}
// Register the pass. Fix: the two adjacent string literals previously
// concatenated to "...operator.Insert..." — a space was missing.
NNVM_REGISTER_PASS(PlaceDevice)
.describe("Infer the device type of each operator. "
          "Insert a copy node when there is cross device copy")
.set_body(PlaceDevice)
.set_change_graph(true)
.provide_graph_attr("device")
.depend_graph_attr("device_group_attr_key")
.depend_graph_attr("device_assign_map")
.depend_graph_attr("device_copy_op");

// allow DeviceAssignMap to round-trip through the JSON attribute interface
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);
} // namespace pass
} // namespace nnvm
...@@ -68,14 +68,18 @@ struct JSONNode { ...@@ -68,14 +68,18 @@ struct JSONNode {
writer->BeginObject(); writer->BeginObject();
if (node->op != nullptr) { if (node->op != nullptr) {
writer->WriteObjectKeyValue("op", node->op->name); writer->WriteObjectKeyValue("op", node->op->name);
writer->WriteObjectKeyValue("attr", node->attrs.dict);
} else { } else {
std::string json_null = "null"; std::string json_null = "null";
writer->WriteObjectKeyValue("op", json_null); writer->WriteObjectKeyValue("op", json_null);
} }
writer->WriteObjectKeyValue("name", node->attrs.name); writer->WriteObjectKeyValue("name", node->attrs.name);
if (node->attrs.dict.size() != 0) {
writer->WriteObjectKeyValue("attr", node->attrs.dict);
}
writer->WriteObjectKeyValue("inputs", inputs); writer->WriteObjectKeyValue("inputs", inputs);
if (control_deps.size() != 0) {
writer->WriteObjectKeyValue("control_deps", control_deps); writer->WriteObjectKeyValue("control_deps", control_deps);
}
writer->EndObject(); writer->EndObject();
} }
......
...@@ -76,6 +76,25 @@ def test_infer_type(): ...@@ -76,6 +76,25 @@ def test_infer_type():
assert g.json_attr('dtype')[jnode_row_ptr[nindex["cast1"]]] == 1 assert g.json_attr('dtype')[jnode_row_ptr[nindex["cast1"]]] == 1
assert g.json_attr('dtype')[jnode_row_ptr[nindex["add1"]]] == 0 assert g.json_attr('dtype')[jnode_row_ptr[nindex["add1"]]] == 0
def test_place_device():
    """Run the PlaceDevice pass on a two-stage graph and verify that the
    per-node 'device' attribute reflects the device_group hints."""
    x = sym.Variable('x', device_group="stage1")
    a = sym.add(x, x, name='add1')
    c = sym.cast(a, dtype=1, name="cast1")
    b = sym.add(c, c, device_group="stage2", name="add2")
    out = sym.add(b, sym.exp(c, device_group="stage2"), name="add3")

    g = graph.create(out)
    # configure the pass through graph attributes, then apply it
    g._set_json_attr("device_group_attr_key", "device_group")
    g._set_json_attr("device_assign_map", {"stage1": 0, "stage2": 1}, "dict_str_int")
    g._set_json_attr("device_copy_op", "cross_device_copy")
    g = g.apply("PlaceDevice")

    jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
    row_ptr = jgraph['node_row_ptr']
    name_to_id = {node['name']: i for i, node in enumerate(jgraph['nodes'])}
    placements = g.json_attr('device')

    def device_of(name):
        # 'device' is indexed by the entry id of the node's first output
        return placements[row_ptr[name_to_id[name]]]

    assert device_of("add2") == 1
    assert device_of("add3") == 1
    assert device_of("cast1") == 0
if __name__ == "__main__": if __name__ == "__main__":
test_order_mutation_pass() test_order_mutation_pass()
...@@ -83,3 +102,4 @@ if __name__ == "__main__": ...@@ -83,3 +102,4 @@ if __name__ == "__main__":
test_json_pass() test_json_pass()
test_infer_shape() test_infer_shape()
test_infer_type() test_infer_type()
test_place_device()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment