Commit cc7cbbe7 by Lianmin Zheng, committed by Tianqi Chen

[BUILD] add target_host to compiler.build (#240)

parent 2b15684f
@@ -48,6 +48,12 @@ class Graph {
   template<typename T>
   inline const T& GetAttr(const std::string& attr_name) const;
   /*!
+   * \brief Check whether the graph has a specific attribute.
+   * \param attr_name the name of the attribute
+   * \return a boolean result
+   */
+  inline bool HasAttr(const std::string& attr_name) const;
+  /*!
    * \brief Get a move copy of the attribute, implement copy on write semantics.
    *  The content is moved if the reference counter of shared_ptr is 1.
    *  The attribute is erased from attrs after the call.
@@ -226,6 +232,11 @@ inline const T& Graph::GetAttr(const std::string& attr_name) const {
   return nnvm::get<T>(*it->second);
 }
 
+inline bool Graph::HasAttr(const std::string& attr_name) const {
+  auto it = attrs.find(attr_name);
+  return it != attrs.end();
+}
+
 template<typename T>
 inline T Graph::MoveCopyAttr(const std::string& attr_name) {
   auto it = attrs.find(attr_name);
...
@@ -112,8 +112,10 @@ def _lower(sch, inputs, func_name, graph):
 
 @tvm.register_func("nnvm.compiler.build_target")
-def _build(funcs, target):
-    return tvm.build(funcs, target=target)
+def _build(funcs, target, target_host):
+    if target_host == "":
+        target_host = None
+    return tvm.build(funcs, target=target, target_host=target_host)
 
 
 def _update_shape_dtype(shape, dtype, params):
@@ -161,7 +163,7 @@ def optimize(graph, shape, dtype="float32"):
     return graph
 
 
-def build(graph, target=None, shape=None, dtype="float32", params=None):
+def build(graph, target=None, shape=None, dtype="float32", params=None, target_host=None):
     """Build graph into runtime library.
 
     The build function will optimize the graph and do the compilation.
@@ -189,6 +191,15 @@ def build(graph, target=None, shape=None, dtype="float32", params=None):
         during inference time. Used for pre-compute
        folding optimization.
 
+    target_host : str or :any:`tvm.target.Target`, optional
+        Host compilation target, used when the target is a device.
+        When TVM compiles a device-specific program such as CUDA,
+        we also need host (CPU) side code to interact with the driver
+        and set up the dimensions and parameters correctly.
+        target_host specifies the host-side codegen target.
+        By default, llvm is used if it is enabled;
+        otherwise a stackvm interpreter is used.
+
     Returns
     -------
     graph : Graph
@@ -228,6 +239,8 @@ def build(graph, target=None, shape=None, dtype="float32", params=None):
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", str(target), "str")
+    if target_host is not None:
+        graph._set_json_attr("target_host", str(target_host), "str")
     if cfg.pass_enabled("OpFusion"):
         graph._set_json_attr("opt_level", 1, "int")
     else:
...
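A minimal usage sketch (not part of this commit) of the new target_host argument: device code is generated for CUDA, while LLVM generates the host-side code that sets up and launches the kernels. The one-operator symbol and its shape below are hypothetical placeholders, and a CUDA-enabled TVM build is assumed.

    import nnvm.symbol as sym
    import nnvm.compiler

    # Hypothetical one-operator graph, used only for illustration.
    x = sym.Variable("x")
    y = sym.exp(x)
    shape = {"x": (1, 10)}

    # target selects the device codegen (CUDA kernels); target_host
    # selects the codegen for the CPU-side driver code.
    graph, lib, params = nnvm.compiler.build(
        y, target="cuda", shape=shape, target_host="llvm")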
@@ -219,6 +219,10 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
   const std::vector<TOpPattern>& pattern_vec =
       g.GetAttr<std::vector<TOpPattern> >("pattern");
   std::string target = g.GetAttr<std::string>("target");
+  std::string target_host;
+  if (g.HasAttr("target_host"))
+    target_host = g.GetAttr<std::string>("target_host");
+
   std::vector<FuseEntry> fuse_vec(idx.num_nodes());
   // setup inputs and placeholder.
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
@@ -398,7 +402,7 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
   ret.attrs["dltype"] = std::make_shared<any>(std::move(new_dltype_vec));
   // Setup module
   static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target");
-  tvm::runtime::Module module = fbuild(func_list, target);
+  tvm::runtime::Module module = fbuild(func_list, target, target_host);
   ret.attrs["module"] = std::make_shared<any>(std::move(module));
   ret = nnvm::ApplyPass(ret, "PlanMemory");
   return ret;
...
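Continuing the sketch above, target_host also covers cross-compilation, where the host CPU is not the machine running the compiler; the LLVM target triple below is an illustrative assumption and depends on the actual board. When target_host is omitted, the graph attribute is simply never set: GraphFuseCompile passes an empty string, the registered _build callback converts it back to None, and tvm.build falls back to llvm if it is enabled, otherwise the stackvm interpreter.

    # Hypothetical cross-compilation setup: OpenCL kernels for the device,
    # host-side code cross-compiled with LLVM for an ARM CPU (assumed triple).
    arm_host = "llvm -target=armv7l-none-linux-gnueabihf"
    graph, lib, params = nnvm.compiler.build(
        y, target="opencl", shape=shape, target_host=arm_host)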