Commit 647267da by Tianqi Chen

Allow libaray path to be configurable (#50)

* Allow libaray path to be configurable

* Enable partial shape inference result to be passed via shape

* fix python3

* disallow copy assign in index
parent bf8bd967
......@@ -171,6 +171,8 @@ class IndexedGraph {
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}
// disalllow copy assign
IndexedGraph(const IndexedGraph&) = delete;
private:
friend class Graph;
......
# coding: utf-8
"""Information about nnvm."""
from __future__ import absolute_import
import sys
import os
import platform
if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__
def find_lib_path():
"""Find NNNet dynamic library files.
......@@ -12,10 +18,19 @@ def find_lib_path():
lib_path : list(string)
List of all found path to the libraries
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if hasattr(__builtin__, "NNVM_BASE_PATH"):
base_path = __builtin__.NNVM_BASE_PATH
else:
base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
lib_name = __builtin__.NNVM_LIBRARY_NAME
else:
lib_name = "libnnvm_example"
api_path = os.path.join(base_path, '../../lib/')
cmake_build_path = os.path.join(base_path, '../../build/Release/')
dll_path = [base_path, api_path, cmake_build_path]
if os.name == 'nt':
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
......@@ -27,9 +42,9 @@ def find_lib_path():
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path]
dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path]
else:
dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path]
dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
......
......@@ -57,7 +57,7 @@ class Symbol(SymbolBase):
def __mul__(self, other):
if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other)
if isinstance(other, Number):
if isinstance(other, _Number):
return _internal.__mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
......
......@@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), default_val);
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else {
rshape.resize(idx.num_node_entries(), default_val);
}
if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
......@@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret,
// erase the provided arguments
ret.attrs.erase(input_name);
}
std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
......
......@@ -59,6 +59,22 @@ def test_infer_shape():
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_shape_known_partial():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
shape = [[4, 2], [] , []]
g._set_json_attr("shape", shape, 'list_shape')
g = g.apply("InferShape")
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_type():
x = sym.Variable('x')
y = sym.add(x, x, name='add1')
......@@ -116,6 +132,7 @@ if __name__ == "__main__":
test_graph_json_attr()
test_json_pass()
test_infer_shape()
test_infer_shape_known_partial()
test_infer_type()
test_place_device()
test_plan_memory()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment