Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1f7712ae
Commit
1f7712ae
authored
Jul 20, 2017
by
Tianqi Chen
Committed by
GitHub
Jul 20, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Add reflection routine to construct node (#265)
parent
68c4400e
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
195 additions
and
19 deletions
+195
-19
include/tvm/target_info.h
+40
-0
python/tvm/make.py
+27
-0
src/api/api_base.cc
+1
-1
src/lang/expr.cc
+0
-4
src/lang/reflection.cc
+78
-2
src/lang/target_info.cc
+19
-0
tests/python/unittest/test_lang_basic.py
+0
-12
tests/python/unittest/test_lang_reflection.py
+30
-0
No files found.
include/tvm/target_info.h
0 → 100644
View file @
1f7712ae
/*!
* Copyright (c) 2017 by Contributors
* \file target_info.h
* \brief Various information about target.
*/
#ifndef TVM_TARGET_INFO_H_
#define TVM_TARGET_INFO_H_
#include "./base.h"
#include "./expr.h"
namespace
tvm
{
/*!
* \brief Memory information of special memory region.
* Use MemoryInfo as its container type
*/
struct
MemoryInfoNode
:
public
Node
{
/*! \brief The addressable unit */
int
unit_bits
;
/*! \brief Maximum number of bits supported in the memory */
int
max_num_bits
;
/*! \brief maximum number of bits to be used in simd op */
int
max_simd_bits
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"unit_bits"
,
&
unit_bits
);
v
->
Visit
(
"max_num_bits"
,
&
max_num_bits
);
v
->
Visit
(
"max_simd_bits"
,
&
max_simd_bits
);
}
static
constexpr
const
char
*
_type_key
=
"MemoryInfo"
;
TVM_DECLARE_NODE_TYPE_INFO
(
MemoryInfoNode
,
Node
);
};
/*! \brief Defines memory info */
TVM_DEFINE_NODE_REF
(
MemoryInfo
,
MemoryInfoNode
);
}
// namespace tvm
#endif // TVM_TARGET_INFO_H_
python/tvm/make.py
View file @
1f7712ae
...
...
@@ -30,6 +30,33 @@ def range_by_min_extent(min_value, extent):
return
_range_by_min_extent
(
min_value
,
extent
)
def
node
(
type_key
,
**
kwargs
):
"""Make a new DSL node by its type key and fields
Parameters
----------
type_key : str
The type key of the node.
**kwargs : dict
The fields of the node.
Example
-------
The following code constructs a IntImm object
.. code-block:: python
x = tvm.make.node("IntImm", dtype="int32", value=10)
assert isinstance(x, tvm.expr.IntImm)
assert x.value == 10
"""
args
=
[
type_key
]
for
k
,
v
in
kwargs
.
items
():
args
+=
[
k
,
v
]
return
_Node
(
*
args
)
def
stmt_seq
(
*
args
):
"""Make sequence of statements
...
...
src/api/api_base.cc
View file @
1f7712ae
...
...
@@ -30,7 +30,7 @@ TVM_REGISTER_API("_save_json")
TVM_REGISTER_API
(
"_load_json"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
NodeRef
(
LoadJSON_
(
args
[
0
])
);
*
ret
=
LoadJSON
<
NodeRef
>
(
args
[
0
]
);
});
TVM_REGISTER_API
(
"_nop"
)
...
...
src/lang/expr.cc
View file @
1f7712ae
...
...
@@ -8,10 +8,6 @@
#include <ir/IRPrinter.h>
#include <memory>
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
NodeFactoryReg
);
}
// namespace dmlc
namespace
tvm
{
using
Halide
::
IR
::
RangeNode
;
...
...
src/lang/
saveload_js
on.cc
→
src/lang/
reflecti
on.cc
View file @
1f7712ae
/*!
* Copyright (c) 2016 by Contributors
* \file
saveload_js
on.cc
* \brief Utilities to save/load
TVM objects.
* \file
reflecti
on.cc
* \brief Utilities to save/load
/construct TVM objects
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/container.h>
#include <tvm/packed_func_ext.h>
#include <dmlc/json.h>
#include <string>
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
NodeFactoryReg
);
}
// namespace dmlc
namespace
tvm
{
inline
std
::
string
Type2String
(
const
Type
&
t
)
{
...
...
@@ -334,4 +339,75 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
return
nodes
.
at
(
jgraph
.
root
);
}
class
NodeAttrSetter
:
public
AttrVisitor
{
public
:
std
::
string
type_key
;
std
::
unordered_map
<
std
::
string
,
runtime
::
TVMArgValue
>
attrs
;
template
<
typename
T
>
void
SetValue
(
const
char
*
key
,
T
*
value
)
{
auto
it
=
attrs
.
find
(
key
);
if
(
it
==
attrs
.
end
())
{
LOG
(
FATAL
)
<<
type_key
<<
": require field "
<<
key
;
}
*
value
=
it
->
second
.
operator
T
();
attrs
.
erase
(
it
);
}
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
uint64_t
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
int
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
bool
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
std
::
string
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
Type
*
value
)
final
{
SetValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
SetValue
(
key
,
value
);
}
};
// API function to make node.
// args format:
// type_key, key1, value1, ..., key_n, value_n
void
MakeNode
(
runtime
::
TVMArgs
args
,
runtime
::
TVMRetValue
*
rv
)
{
NodeAttrSetter
setter
;
setter
.
type_key
=
args
[
0
].
operator
std
::
string
();
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
setter
.
attrs
.
emplace
(
args
[
i
].
operator
std
::
string
(),
runtime
::
TVMArgValue
(
args
.
values
[
i
+
1
],
args
.
type_codes
[
i
+
1
]));
}
auto
*
f
=
dmlc
::
Registry
<
NodeFactoryReg
>::
Find
(
setter
.
type_key
);
CHECK
(
f
!=
nullptr
)
<<
"Node type
\'
"
<<
setter
.
type_key
<<
"
\'
is not registered in TVM"
;
std
::
shared_ptr
<
Node
>
n
=
f
->
body
();
n
->
VisitAttrs
(
&
setter
);
if
(
setter
.
attrs
.
size
()
!=
0
)
{
std
::
ostringstream
os
;
os
<<
setter
.
type_key
<<
" does not contain field "
;
for
(
const
auto
&
kv
:
setter
.
attrs
)
{
os
<<
" "
<<
kv
.
first
;
}
LOG
(
FATAL
)
<<
os
.
str
();
}
*
rv
=
NodeRef
(
n
);
}
TVM_REGISTER_GLOBAL
(
"make._Node"
)
.
set_body
(
MakeNode
);
}
// namespace tvm
src/lang/target_info.cc
0 → 100644
View file @
1f7712ae
/*!
* Copyright (c) 2017 by Contributors
* \file target_info.cc
*/
#include <tvm/target_info.h>
namespace
tvm
{
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
MemoryInfoNode
>
([](
const
MemoryInfoNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"mem-info("
<<
"unit_bits="
<<
op
->
unit_bits
<<
", "
<<
"max_num_bits="
<<
op
->
max_num_bits
<<
", "
<<
"max_simd_bits="
<<
op
->
max_simd_bits
<<
")"
;
});
TVM_REGISTER_NODE_TYPE
(
MemoryInfoNode
);
}
// namespace tvm
tests/python/unittest/test_lang_basic.py
View file @
1f7712ae
...
...
@@ -6,21 +6,10 @@ def test_const():
assert
x
.
dtype
==
tvm
.
int32
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
def
test_const_saveload_json
():
# save load json
x
=
tvm
.
const
(
1
)
y
=
tvm
.
const
(
10
)
z
=
x
+
y
z
=
z
+
z
json_str
=
tvm
.
save_json
(
z
)
zz
=
tvm
.
load_json
(
json_str
)
assert
tvm
.
save_json
(
zz
)
==
tvm
.
save_json
(
z
)
def
test_make
():
x
=
tvm
.
const
(
1
)
y
=
tvm
.
make
.
IntImm
(
'int32'
,
1
)
z
=
x
+
y
print
(
z
)
def
test_ir
():
x
=
tvm
.
const
(
1
)
...
...
@@ -129,7 +118,6 @@ def test_all():
if
__name__
==
"__main__"
:
test_attr
()
test_const
()
test_const_saveload_json
()
test_make
()
test_ir
()
test_basic
()
...
...
tests/python/unittest/test_lang_reflection.py
0 → 100644
View file @
1f7712ae
import
tvm
def
test_const_saveload_json
():
# save load json
x
=
tvm
.
const
(
1
)
y
=
tvm
.
const
(
10
)
z
=
x
+
y
z
=
z
+
z
json_str
=
tvm
.
save_json
(
z
)
zz
=
tvm
.
load_json
(
json_str
)
assert
tvm
.
save_json
(
zz
)
==
tvm
.
save_json
(
z
)
def
test_make_node
():
x
=
tvm
.
make
.
node
(
"IntImm"
,
dtype
=
"int32"
,
value
=
10
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
assert
x
.
value
==
10
A
=
tvm
.
placeholder
((
10
,
),
name
=
'A'
)
AA
=
tvm
.
make
.
node
(
"Tensor"
,
shape
=
A
.
shape
,
dtype
=
A
.
dtype
,
op
=
A
.
op
,
value_index
=
A
.
value_index
)
assert
AA
.
op
==
A
.
op
assert
AA
.
value_index
==
A
.
value_index
if
__name__
==
"__main__"
:
test_make_node
()
test_const_saveload_json
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment