Commit 146714ac by Tianqi Chen Committed by GitHub

[CONTAINER] Introduce StrMap (#1292)

parent c9703594
......@@ -196,7 +196,7 @@ if(GTEST_LIB)
add_executable(${__execname} ${__srcpath})
list(APPEND TEST_EXECS ${__execname})
target_link_libraries(${__execname}
tvm ${GTEST_LIB} ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS} pthread)
tvm ${GTEST_LIB} pthread)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
endforeach()
......
Subproject commit a3698398faff7fec1c0fa4e4479357651382db75
Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a
......@@ -60,6 +60,25 @@ struct NodeTypeChecker<Array<T> > {
}
};
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
......
......@@ -30,6 +30,7 @@ RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class NodeBase(object):
__slots__ = ["handle"]
# pylint: disable=no-member
......
......@@ -13,12 +13,14 @@ def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
def convert_to_node(value):
"""Convert a python value to corresponding node type.
......@@ -46,7 +48,8 @@ def convert_to_node(value):
elif isinstance(value, dict):
vlist = []
for item in value.items():
if not isinstance(item[0], _CLASS_NODE_BASE):
if (not isinstance(item[0], _CLASS_NODE_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
......@@ -56,6 +59,7 @@ def convert_to_node(value):
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
def const(value, dtype=None):
"""Construct a constant value for a given type.
......
......@@ -32,9 +32,8 @@ class Map(NodeBase):
"""Map container of TVM.
You do not need to create Map explicitly.
Normally python dict will be converted automatically
to Array during tvm function call.
You may get Map in return values of TVM function call.
Normally python dict will be converted automaticall to Map during tvm function call.
You can use convert to create a dict[NodeBase-> NodeBase] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
......@@ -52,6 +51,18 @@ class Map(NodeBase):
@register_node
class StrMap(Map):
"""A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map.
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
@register_node
class Range(NodeBase):
"""Represent range in TVM.
......
......@@ -76,56 +76,92 @@ TVM_REGISTER_API("_ArraySize")
TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kStr) {
// StrMap
StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<StrMapNode>();
node->data = std::move(data);
*ret = node;
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
*ret = node;
}
});
TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
}
});
TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
if (sptr->is_type<MapNode>()) {
CHECK(args[1].type_code() == kNodeHandle);
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args[1].node_sptr());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto it = n->data.find(args[1].operator std::string());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
});
TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
CHECK(args[1].type_code() == kNodeHandle);
*ret = static_cast<int64_t>(
n->data.count(args[1].node_sptr()));
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(
n->data.count(args[1].operator std::string()));
}
});
TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
......@@ -133,6 +169,15 @@ TVM_REGISTER_API("_MapItems")
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
} else {
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
}
});
TVM_REGISTER_API("Range")
......
......@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode);
......
......@@ -35,6 +35,15 @@ TEST(Map, Expr) {
CHECK(!dict.count(zz));
}
TEST(StrMap, Expr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Map<std::string, Expr> dict{{"x", z}, {"z", 2}};
CHECK(dict.size() == 2);
CHECK(dict["x"].same_as(z));
}
TEST(Map, Mutate) {
using namespace tvm;
Var x("x");
......
......@@ -10,6 +10,7 @@ def test_array_save_load_json():
a_loaded = tvm.load_json(json_str)
assert(a[1].value == 2)
def test_map():
a = tvm.var('a')
b = tvm.var('b')
......@@ -22,6 +23,17 @@ def test_map():
assert b in dd
assert a + 1 not in amap
def test_str_map():
amap = tvm.convert({'a': 2, 'b': 3})
assert 'a' in amap
assert len(amap) == 2
dd = dict(amap.items())
assert amap['a'].value == 2
assert 'a' in dd
assert 'b' in dd
def test_map_save_load_json():
a = tvm.var('a')
b = tvm.var('b')
......@@ -35,6 +47,7 @@ def test_map_save_load_json():
if __name__ == "__main__":
test_str_map()
test_array()
test_map()
test_array_save_load_json()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment