Commit 647267da by Tianqi Chen

Allow libaray path to be configurable (#50)

* Allow libaray path to be configurable

* Enable partial shape inference result to be passed via shape

* fix python3

* disallow copy assign in index
parent bf8bd967
...@@ -171,6 +171,8 @@ class IndexedGraph { ...@@ -171,6 +171,8 @@ class IndexedGraph {
inline const std::vector<NodeEntry>& outputs() const { inline const std::vector<NodeEntry>& outputs() const {
return outputs_; return outputs_;
} }
// disalllow copy assign
IndexedGraph(const IndexedGraph&) = delete;
private: private:
friend class Graph; friend class Graph;
......
# coding: utf-8 # coding: utf-8
"""Information about nnvm.""" """Information about nnvm."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os import os
import platform import platform
if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__
def find_lib_path(): def find_lib_path():
"""Find NNNet dynamic library files. """Find NNNet dynamic library files.
...@@ -12,10 +18,19 @@ def find_lib_path(): ...@@ -12,10 +18,19 @@ def find_lib_path():
lib_path : list(string) lib_path : list(string)
List of all found path to the libraries List of all found path to the libraries
""" """
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) if hasattr(__builtin__, "NNVM_BASE_PATH"):
api_path = os.path.join(curr_path, '../../lib/') base_path = __builtin__.NNVM_BASE_PATH
cmake_build_path = os.path.join(curr_path, '../../build/Release/') else:
dll_path = [curr_path, api_path, cmake_build_path] base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
lib_name = __builtin__.NNVM_LIBRARY_NAME
else:
lib_name = "libnnvm_example"
api_path = os.path.join(base_path, '../../lib/')
cmake_build_path = os.path.join(base_path, '../../build/Release/')
dll_path = [base_path, api_path, cmake_build_path]
if os.name == 'nt': if os.name == 'nt':
vs_configuration = 'Release' vs_configuration = 'Release'
if platform.architecture()[0] == '64bit': if platform.architecture()[0] == '64bit':
...@@ -27,9 +42,9 @@ def find_lib_path(): ...@@ -27,9 +42,9 @@ def find_lib_path():
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt': if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path] dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path]
else: else:
dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path] dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0: if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
......
...@@ -57,7 +57,7 @@ class Symbol(SymbolBase): ...@@ -57,7 +57,7 @@ class Symbol(SymbolBase):
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Symbol): if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other) return _internal.__mul_symbol__(self, other)
if isinstance(other, Number): if isinstance(other, _Number):
return _internal.__mul_scalar__(self, scalar=other) return _internal.__mul_scalar__(self, scalar=other)
else: else:
raise TypeError('type %s not supported' % str(type(other))) raise TypeError('type %s not supported' % str(type(other)))
......
...@@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret, ...@@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>; using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph(); const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name); Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map = static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex"); Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector // reshape shape vector
AttrVector rshape(idx.num_node_entries(), default_val); AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else {
rshape.resize(idx.num_node_entries(), default_val);
}
if (ret.attrs.count(input_name) != 0) { if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name); const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
...@@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret, ...@@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret,
// erase the provided arguments // erase the provided arguments
ret.attrs.erase(input_name); ret.attrs.erase(input_name);
} }
std::string shape_attr_key; std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) { if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name); shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
......
...@@ -59,6 +59,22 @@ def test_infer_shape(): ...@@ -59,6 +59,22 @@ def test_infer_shape():
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4] assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2] assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_shape_known_partial():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
shape = [[4, 2], [] , []]
g._set_json_attr("shape", shape, 'list_shape')
g = g.apply("InferShape")
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_type(): def test_infer_type():
x = sym.Variable('x') x = sym.Variable('x')
y = sym.add(x, x, name='add1') y = sym.add(x, x, name='add1')
...@@ -116,6 +132,7 @@ if __name__ == "__main__": ...@@ -116,6 +132,7 @@ if __name__ == "__main__":
test_graph_json_attr() test_graph_json_attr()
test_json_pass() test_json_pass()
test_infer_shape() test_infer_shape()
test_infer_shape_known_partial()
test_infer_type() test_infer_type()
test_place_device() test_place_device()
test_plan_memory() test_plan_memory()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment