Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
146714ac
Commit
146714ac
authored
Jun 15, 2018
by
Tianqi Chen
Committed by
GitHub
Jun 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[CONTAINER] Introduce StrMap (#1292)
parent
c9703594
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
117 additions
and
14 deletions
+117
-14
CMakeLists.txt
+1
-1
HalideIR
+1
-1
include/tvm/packed_func_ext.h
+19
-0
python/tvm/_ffi/_ctypes/node.py
+1
-0
python/tvm/_ffi/node_generic.py
+5
-1
python/tvm/container.py
+14
-3
src/api/api_lang.cc
+53
-8
src/lang/expr.cc
+1
-0
tests/cpp/container_test.cc
+9
-0
tests/python/unittest/test_lang_container.py
+13
-0
No files found.
CMakeLists.txt
View file @
146714ac
...
...
@@ -196,7 +196,7 @@ if(GTEST_LIB)
add_executable
(
${
__execname
}
${
__srcpath
}
)
list
(
APPEND TEST_EXECS
${
__execname
}
)
target_link_libraries
(
${
__execname
}
tvm
${
GTEST_LIB
}
${
TVM_LINKER_LIBS
}
${
TVM_RUNTIME_LINKER_LIBS
}
pthread
)
tvm
${
GTEST_LIB
}
pthread
)
set_target_properties
(
${
__execname
}
PROPERTIES EXCLUDE_FROM_ALL 1
)
set_target_properties
(
${
__execname
}
PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1
)
endforeach
()
...
...
HalideIR
@
0b7e2527
Subproject commit
a3698398faff7fec1c0fa4e4479357651382db75
Subproject commit
0b7e25275138768bb05edb9b9db2c86d0fb09c9a
include/tvm/packed_func_ext.h
View file @
146714ac
...
...
@@ -60,6 +60,25 @@ struct NodeTypeChecker<Array<T> > {
}
};
template
<
typename
V
>
struct
NodeTypeChecker
<
Map
<
std
::
string
,
V
>
>
{
static
inline
bool
Check
(
Node
*
sptr
)
{
if
(
sptr
==
nullptr
)
return
false
;
if
(
!
sptr
->
is_type
<
StrMapNode
>
())
return
false
;
StrMapNode
*
n
=
static_cast
<
StrMapNode
*>
(
sptr
);
for
(
const
auto
&
kv
:
n
->
data
)
{
if
(
!
NodeTypeChecker
<
V
>::
Check
(
kv
.
second
.
get
()))
return
false
;
}
return
true
;
}
static
inline
void
PrintName
(
std
::
ostringstream
&
os
)
{
// NOLINT(*)
os
<<
"map<string"
;
os
<<
','
;
NodeTypeChecker
<
V
>::
PrintName
(
os
);
os
<<
'>'
;
}
};
template
<
typename
K
,
typename
V
>
struct
NodeTypeChecker
<
Map
<
K
,
V
>
>
{
static
inline
bool
Check
(
Node
*
sptr
)
{
...
...
python/tvm/_ffi/_ctypes/node.py
View file @
146714ac
...
...
@@ -30,6 +30,7 @@ RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node
C_TO_PY_ARG_SWITCH
[
TypeCode
.
NODE_HANDLE
]
=
_wrap_arg_func
(
_return_node
,
TypeCode
.
NODE_HANDLE
)
class
NodeBase
(
object
):
__slots__
=
[
"handle"
]
# pylint: disable=no-member
...
...
python/tvm/_ffi/node_generic.py
View file @
146714ac
...
...
@@ -13,12 +13,14 @@ def _set_class_node_base(cls):
global
_CLASS_NODE_BASE
_CLASS_NODE_BASE
=
cls
class
NodeGeneric
(
object
):
"""Base class for all classes that can be converted to node."""
def
asnode
(
self
):
"""Convert value to node"""
raise
NotImplementedError
()
def
convert_to_node
(
value
):
"""Convert a python value to corresponding node type.
...
...
@@ -46,7 +48,8 @@ def convert_to_node(value):
elif
isinstance
(
value
,
dict
):
vlist
=
[]
for
item
in
value
.
items
():
if
not
isinstance
(
item
[
0
],
_CLASS_NODE_BASE
):
if
(
not
isinstance
(
item
[
0
],
_CLASS_NODE_BASE
)
and
not
isinstance
(
item
[
0
],
string_types
)):
raise
ValueError
(
"key of map must already been a container type"
)
vlist
.
append
(
item
[
0
])
vlist
.
append
(
convert_to_node
(
item
[
1
]))
...
...
@@ -56,6 +59,7 @@ def convert_to_node(value):
else
:
raise
ValueError
(
"don't know how to convert type
%
s to node"
%
type
(
value
))
def
const
(
value
,
dtype
=
None
):
"""Construct a constant value for a given type.
...
...
python/tvm/container.py
View file @
146714ac
...
...
@@ -32,9 +32,8 @@ class Map(NodeBase):
"""Map container of TVM.
You do not need to create Map explicitly.
Normally python dict will be converted automatically
to Array during tvm function call.
You may get Map in return values of TVM function call.
Normally python dict will be converted automaticall to Map during tvm function call.
You can use convert to create a dict[NodeBase-> NodeBase] into a Map
"""
def
__getitem__
(
self
,
k
):
return
_api_internal
.
_MapGetItem
(
self
,
k
)
...
...
@@ -52,6 +51,18 @@ class Map(NodeBase):
@register_node
class
StrMap
(
Map
):
"""A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map.
"""
def
items
(
self
):
"""Get the items from the map"""
akvs
=
_api_internal
.
_MapItems
(
self
)
return
[(
akvs
[
i
]
.
value
,
akvs
[
i
+
1
])
for
i
in
range
(
0
,
len
(
akvs
),
2
)]
@register_node
class
Range
(
NodeBase
):
"""Represent range in TVM.
...
...
src/api/api_lang.cc
View file @
146714ac
...
...
@@ -76,56 +76,92 @@ TVM_REGISTER_API("_ArraySize")
TVM_REGISTER_API
(
"_Map"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK_EQ
(
args
.
size
()
%
2
,
0
);
if
(
args
.
size
()
!=
0
&&
args
[
0
].
type_code
()
==
kStr
)
{
// StrMap
StrMapNode
::
ContainerType
data
;
for
(
int
i
=
0
;
i
<
args
.
num_args
;
i
+=
2
)
{
CHECK
(
args
[
i
].
type_code
()
==
kStr
)
<<
"key of str map need to be str"
;
CHECK
(
args
[
i
+
1
].
type_code
()
==
kNodeHandle
)
<<
"value of the map to be NodeRef"
;
data
.
emplace
(
std
::
make_pair
(
args
[
i
].
operator
std
::
string
(),
args
[
i
+
1
].
node_sptr
()));
}
auto
node
=
std
::
make_shared
<
StrMapNode
>
();
node
->
data
=
std
::
move
(
data
);
*
ret
=
node
;
}
else
{
// Container node.
MapNode
::
ContainerType
data
;
for
(
int
i
=
0
;
i
<
args
.
num_args
;
i
+=
2
)
{
CHECK
(
args
[
i
].
type_code
()
==
kNodeHandle
)
<<
"need content of array to be NodeBase
"
;
<<
"key of str map need to be str
"
;
CHECK
(
args
[
i
+
1
].
type_code
()
==
kNodeHandle
)
<<
"need content of array to be NodeBase
"
;
<<
"value of map to be NodeRef
"
;
data
.
emplace
(
std
::
make_pair
(
args
[
i
].
node_sptr
(),
args
[
i
+
1
].
node_sptr
()));
}
auto
node
=
std
::
make_shared
<
MapNode
>
();
node
->
data
=
std
::
move
(
data
);
*
ret
=
node
;
}
});
TVM_REGISTER_API
(
"_MapSize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
if
(
sptr
->
is_type
<
MapNode
>
())
{
auto
*
n
=
static_cast
<
const
MapNode
*>
(
sptr
.
get
());
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
size
());
}
else
{
CHECK
(
sptr
->
is_type
<
StrMapNode
>
());
auto
*
n
=
static_cast
<
const
StrMapNode
*>
(
sptr
.
get
());
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
size
());
}
});
TVM_REGISTER_API
(
"_MapGetItem"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
if
(
sptr
->
is_type
<
MapNode
>
())
{
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
auto
*
n
=
static_cast
<
const
MapNode
*>
(
sptr
.
get
());
auto
it
=
n
->
data
.
find
(
args
[
1
].
node_sptr
());
CHECK
(
it
!=
n
->
data
.
end
())
<<
"cannot find the corresponding key in the Map"
;
*
ret
=
(
*
it
).
second
;
}
else
{
CHECK
(
sptr
->
is_type
<
StrMapNode
>
());
auto
*
n
=
static_cast
<
const
StrMapNode
*>
(
sptr
.
get
());
auto
it
=
n
->
data
.
find
(
args
[
1
].
operator
std
::
string
());
CHECK
(
it
!=
n
->
data
.
end
())
<<
"cannot find the corresponding key in the Map"
;
*
ret
=
(
*
it
).
second
;
}
});
TVM_REGISTER_API
(
"_MapCount"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
CHECK
(
args
[
0
].
type_code
()
==
kNodeHandle
);
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
if
(
sptr
->
is_type
<
MapNode
>
())
{
auto
*
n
=
static_cast
<
const
MapNode
*>
(
sptr
.
get
());
CHECK
(
args
[
1
].
type_code
()
==
kNodeHandle
);
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
count
(
args
[
1
].
node_sptr
()));
}
else
{
CHECK
(
sptr
->
is_type
<
StrMapNode
>
());
auto
*
n
=
static_cast
<
const
StrMapNode
*>
(
sptr
.
get
());
*
ret
=
static_cast
<
int64_t
>
(
n
->
data
.
count
(
args
[
1
].
operator
std
::
string
()));
}
});
TVM_REGISTER_API
(
"_MapItems"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
&
sptr
=
args
[
0
].
node_sptr
();
CHECK
(
sptr
->
is_type
<
MapNode
>
());
if
(
sptr
->
is_type
<
MapNode
>
())
{
auto
*
n
=
static_cast
<
const
MapNode
*>
(
sptr
.
get
());
auto
rkvs
=
std
::
make_shared
<
ArrayNode
>
();
for
(
const
auto
&
kv
:
n
->
data
)
{
...
...
@@ -133,6 +169,15 @@ TVM_REGISTER_API("_MapItems")
rkvs
->
data
.
push_back
(
kv
.
second
);
}
*
ret
=
rkvs
;
}
else
{
auto
*
n
=
static_cast
<
const
StrMapNode
*>
(
sptr
.
get
());
auto
rkvs
=
std
::
make_shared
<
ArrayNode
>
();
for
(
const
auto
&
kv
:
n
->
data
)
{
rkvs
->
data
.
push_back
(
ir
::
StringImm
::
make
(
kv
.
first
).
node_
);
rkvs
->
data
.
push_back
(
kv
.
second
);
}
*
ret
=
rkvs
;
}
});
TVM_REGISTER_API
(
"Range"
)
...
...
src/lang/expr.cc
View file @
146714ac
...
...
@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
ArrayNode
);
TVM_REGISTER_NODE_TYPE
(
MapNode
);
TVM_REGISTER_NODE_TYPE
(
StrMapNode
);
TVM_REGISTER_NODE_TYPE
(
RangeNode
);
TVM_REGISTER_NODE_TYPE
(
IterVarNode
);
...
...
tests/cpp/container_test.cc
View file @
146714ac
...
...
@@ -35,6 +35,15 @@ TEST(Map, Expr) {
CHECK
(
!
dict
.
count
(
zz
));
}
TEST
(
StrMap
,
Expr
)
{
using
namespace
tvm
;
Var
x
(
"x"
);
auto
z
=
max
(
x
+
1
+
2
,
100
);
Map
<
std
::
string
,
Expr
>
dict
{{
"x"
,
z
},
{
"z"
,
2
}};
CHECK
(
dict
.
size
()
==
2
);
CHECK
(
dict
[
"x"
].
same_as
(
z
));
}
TEST
(
Map
,
Mutate
)
{
using
namespace
tvm
;
Var
x
(
"x"
);
...
...
tests/python/unittest/test_lang_container.py
View file @
146714ac
...
...
@@ -10,6 +10,7 @@ def test_array_save_load_json():
a_loaded
=
tvm
.
load_json
(
json_str
)
assert
(
a
[
1
]
.
value
==
2
)
def
test_map
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
...
...
@@ -22,6 +23,17 @@ def test_map():
assert
b
in
dd
assert
a
+
1
not
in
amap
def
test_str_map
():
amap
=
tvm
.
convert
({
'a'
:
2
,
'b'
:
3
})
assert
'a'
in
amap
assert
len
(
amap
)
==
2
dd
=
dict
(
amap
.
items
())
assert
amap
[
'a'
]
.
value
==
2
assert
'a'
in
dd
assert
'b'
in
dd
def
test_map_save_load_json
():
a
=
tvm
.
var
(
'a'
)
b
=
tvm
.
var
(
'b'
)
...
...
@@ -35,6 +47,7 @@ def test_map_save_load_json():
if
__name__
==
"__main__"
:
test_str_map
()
test_array
()
test_map
()
test_array_save_load_json
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment