Commit 146714ac by Tianqi Chen Committed by GitHub

[CONTAINER] Introduce StrMap (#1292)

parent c9703594
...@@ -196,7 +196,7 @@ if(GTEST_LIB) ...@@ -196,7 +196,7 @@ if(GTEST_LIB)
add_executable(${__execname} ${__srcpath}) add_executable(${__execname} ${__srcpath})
list(APPEND TEST_EXECS ${__execname}) list(APPEND TEST_EXECS ${__execname})
target_link_libraries(${__execname} target_link_libraries(${__execname}
tvm ${GTEST_LIB} ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS} pthread) tvm ${GTEST_LIB} pthread)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
endforeach() endforeach()
......
Subproject commit a3698398faff7fec1c0fa4e4479357651382db75 Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a
...@@ -60,6 +60,25 @@ struct NodeTypeChecker<Array<T> > { ...@@ -60,6 +60,25 @@ struct NodeTypeChecker<Array<T> > {
} }
}; };
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename K, typename V> template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > { struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) { static inline bool Check(Node* sptr) {
......
...@@ -30,6 +30,7 @@ RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node ...@@ -30,6 +30,7 @@ RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE) _return_node, TypeCode.NODE_HANDLE)
class NodeBase(object): class NodeBase(object):
__slots__ = ["handle"] __slots__ = ["handle"]
# pylint: disable=no-member # pylint: disable=no-member
......
...@@ -13,12 +13,14 @@ def _set_class_node_base(cls): ...@@ -13,12 +13,14 @@ def _set_class_node_base(cls):
global _CLASS_NODE_BASE global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls _CLASS_NODE_BASE = cls
class NodeGeneric(object): class NodeGeneric(object):
"""Base class for all classes that can be converted to node.""" """Base class for all classes that can be converted to node."""
def asnode(self): def asnode(self):
"""Convert value to node""" """Convert value to node"""
raise NotImplementedError() raise NotImplementedError()
def convert_to_node(value): def convert_to_node(value):
"""Convert a python value to corresponding node type. """Convert a python value to corresponding node type.
...@@ -46,7 +48,8 @@ def convert_to_node(value): ...@@ -46,7 +48,8 @@ def convert_to_node(value):
elif isinstance(value, dict): elif isinstance(value, dict):
vlist = [] vlist = []
for item in value.items(): for item in value.items():
if not isinstance(item[0], _CLASS_NODE_BASE): if (not isinstance(item[0], _CLASS_NODE_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type") raise ValueError("key of map must already been a container type")
vlist.append(item[0]) vlist.append(item[0])
vlist.append(convert_to_node(item[1])) vlist.append(convert_to_node(item[1]))
...@@ -56,6 +59,7 @@ def convert_to_node(value): ...@@ -56,6 +59,7 @@ def convert_to_node(value):
else: else:
raise ValueError("don't know how to convert type %s to node" % type(value)) raise ValueError("don't know how to convert type %s to node" % type(value))
def const(value, dtype=None): def const(value, dtype=None):
"""Construct a constant value for a given type. """Construct a constant value for a given type.
......
...@@ -32,9 +32,8 @@ class Map(NodeBase): ...@@ -32,9 +32,8 @@ class Map(NodeBase):
"""Map container of TVM. """Map container of TVM.
You do not need to create Map explicitly. You do not need to create Map explicitly.
Normally python dict will be converted automatically Normally python dict will be converted automaticall to Map during tvm function call.
to Array during tvm function call. You can use convert to create a dict[NodeBase-> NodeBase] into a Map
You may get Map in return values of TVM function call.
""" """
def __getitem__(self, k): def __getitem__(self, k):
return _api_internal._MapGetItem(self, k) return _api_internal._MapGetItem(self, k)
...@@ -52,6 +51,18 @@ class Map(NodeBase): ...@@ -52,6 +51,18 @@ class Map(NodeBase):
@register_node @register_node
class StrMap(Map):
"""A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map.
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_node
class Range(NodeBase): class Range(NodeBase):
"""Represent range in TVM. """Represent range in TVM.
......
...@@ -76,56 +76,92 @@ TVM_REGISTER_API("_ArraySize") ...@@ -76,56 +76,92 @@ TVM_REGISTER_API("_ArraySize")
TVM_REGISTER_API("_Map") TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kStr) {
// StrMap
StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<StrMapNode>();
node->data = std::move(data);
*ret = node;
} else {
// Container node.
MapNode::ContainerType data; MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kNodeHandle) CHECK(args[i].type_code() == kNodeHandle)
<< "need content of array to be NodeBase"; << "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle) CHECK(args[i + 1].type_code() == kNodeHandle)
<< "need content of array to be NodeBase"; << "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].node_sptr(), data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr())); args[i + 1].node_sptr()));
} }
auto node = std::make_shared<MapNode>(); auto node = std::make_shared<MapNode>();
node->data = std::move(data); node->data = std::move(data);
*ret = node; *ret = node;
}
}); });
TVM_REGISTER_API("_MapSize") TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size()); *ret = static_cast<int64_t>(n->data.size());
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
}
}); });
TVM_REGISTER_API("_MapGetItem") TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); if (sptr->is_type<MapNode>()) {
CHECK(args[1].type_code() == kNodeHandle);
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args[1].node_sptr()); auto it = n->data.find(args[1].node_sptr());
CHECK(it != n->data.end()) CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map"; << "cannot find the corresponding key in the Map";
*ret = (*it).second; *ret = (*it).second;
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto it = n->data.find(args[1].operator std::string());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
}); });
TVM_REGISTER_API("_MapCount") TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle); CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
CHECK(args[1].type_code() == kNodeHandle);
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
n->data.count(args[1].node_sptr())); n->data.count(args[1].node_sptr()));
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(
n->data.count(args[1].operator std::string()));
}
}); });
TVM_REGISTER_API("_MapItems") TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr(); auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>()); if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get()); auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>(); auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
...@@ -133,6 +169,15 @@ TVM_REGISTER_API("_MapItems") ...@@ -133,6 +169,15 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.second); rkvs->data.push_back(kv.second);
} }
*ret = rkvs; *ret = rkvs;
} else {
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
}
}); });
TVM_REGISTER_API("Range") TVM_REGISTER_API("Range")
......
...@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ArrayNode); TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode); TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode); TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode); TVM_REGISTER_NODE_TYPE(IterVarNode);
......
...@@ -35,6 +35,15 @@ TEST(Map, Expr) { ...@@ -35,6 +35,15 @@ TEST(Map, Expr) {
CHECK(!dict.count(zz)); CHECK(!dict.count(zz));
} }
TEST(StrMap, Expr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Map<std::string, Expr> dict{{"x", z}, {"z", 2}};
CHECK(dict.size() == 2);
CHECK(dict["x"].same_as(z));
}
TEST(Map, Mutate) { TEST(Map, Mutate) {
using namespace tvm; using namespace tvm;
Var x("x"); Var x("x");
......
...@@ -10,6 +10,7 @@ def test_array_save_load_json(): ...@@ -10,6 +10,7 @@ def test_array_save_load_json():
a_loaded = tvm.load_json(json_str) a_loaded = tvm.load_json(json_str)
assert(a[1].value == 2) assert(a[1].value == 2)
def test_map(): def test_map():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -22,6 +23,17 @@ def test_map(): ...@@ -22,6 +23,17 @@ def test_map():
assert b in dd assert b in dd
assert a + 1 not in amap assert a + 1 not in amap
def test_str_map():
amap = tvm.convert({'a': 2, 'b': 3})
assert 'a' in amap
assert len(amap) == 2
dd = dict(amap.items())
assert amap['a'].value == 2
assert 'a' in dd
assert 'b' in dd
def test_map_save_load_json(): def test_map_save_load_json():
a = tvm.var('a') a = tvm.var('a')
b = tvm.var('b') b = tvm.var('b')
...@@ -35,6 +47,7 @@ def test_map_save_load_json(): ...@@ -35,6 +47,7 @@ def test_map_save_load_json():
if __name__ == "__main__": if __name__ == "__main__":
test_str_map()
test_array() test_array()
test_map() test_map()
test_array_save_load_json() test_array_save_load_json()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment