Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a6724b6e
Unverified
Commit
a6724b6e
authored
Sep 15, 2018
by
Tianqi Chen
Committed by
GitHub
Sep 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE] Enable EnvFunc to serialize global function as node (#1721)
parent
43126602
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
240 additions
and
2 deletions
+240
-2
include/tvm/api_registry.h
+113
-2
include/tvm/runtime/packed_func.h
+8
-0
python/tvm/api.py
+22
-0
python/tvm/container.py
+14
-0
src/api/api_test.cc
+4
-0
src/lang/api_registry.cc
+50
-0
tests/python/unittest/test_lang_reflection.py
+29
-0
No files found.
include/tvm/api_registry.h
View file @
a6724b6e
/*!
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file tvm/api_registry.h
* \file tvm/api_registry.h
* \brief This file
s include necessary headers
to
* \brief This file
contains utilities related
to
*
be used to register an global API function
.
*
the TVM's global function registry
.
*/
*/
#ifndef TVM_API_REGISTRY_H_
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#include <string>
#include "base.h"
#include "base.h"
#include "packed_func_ext.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"
#include "runtime/registry.h"
namespace
tvm
{
/*!
/*!
* \brief Register an API function globally.
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
* It simply redirects to TVM_REGISTER_GLOBAL
...
@@ -24,4 +26,113 @@
...
@@ -24,4 +26,113 @@
*/
*/
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)
/*!
* \brief Node container of EnvFunc
* \sa EnvFunc
*/
class
EnvFuncNode
:
public
Node
{
public
:
/*! \brief Unique name of the global function */
std
::
string
name
;
/*! \brief The internal packed function */
PackedFunc
func
;
/*! \brief constructor */
EnvFuncNode
()
{}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name"
,
&
name
);
}
static
constexpr
const
char
*
_type_key
=
"EnvFunc"
;
TVM_DECLARE_NODE_TYPE_INFO
(
EnvFuncNode
,
Node
);
};
/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
*/
class
EnvFunc
:
public
NodeRef
{
public
:
EnvFunc
()
{}
explicit
EnvFunc
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*! \return The internal global function pointer */
const
EnvFuncNode
*
operator
->
()
const
{
return
static_cast
<
EnvFuncNode
*>
(
node_
.
get
());
}
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
template
<
typename
...
Args
>
runtime
::
TVMRetValue
operator
()(
Args
&&
...
args
)
const
{
const
EnvFuncNode
*
n
=
operator
->
();
CHECK
(
n
!=
nullptr
);
return
n
->
func
(
std
::
forward
<
Args
>
(
args
)...);
}
/*!
* \brief Get a global function based on the name.
* \param name The name of the global function.
* \return The created global function.
* \note The function can be unique
*/
TVM_DLL
static
EnvFunc
Get
(
const
std
::
string
&
name
);
/*! \brief specify container node */
using
ContainerType
=
EnvFuncNode
;
};
/*!
* \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
*/
template
<
typename
FType
>
class
TypedEnvFunc
;
/*!
* \anchor TypedEnvFuncAnchor
* \brief A typed version of EnvFunc.
* It is backed by a GlobalFuncNode internally.
*
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
* \sa EnvFunc
*/
template
<
typename
R
,
typename
...
Args
>
class
TypedEnvFunc
<
R
(
Args
...)
>
:
public
NodeRef
{
public
:
/*! \brief short hand for this function type */
using
TSelf
=
TypedEnvFunc
<
R
(
Args
...)
>
;
TypedEnvFunc
()
{}
explicit
TypedEnvFunc
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
* \return reference to self.
*/
TSelf
&
operator
=
(
const
EnvFunc
&
other
)
{
this
->
node_
=
other
.
node_
;
return
*
this
;
}
/*! \return The internal global function pointer */
const
EnvFuncNode
*
operator
->
()
const
{
return
static_cast
<
EnvFuncNode
*>
(
node_
.
get
());
}
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
R
operator
()(
Args
...
args
)
const
{
const
EnvFuncNode
*
n
=
operator
->
();
CHECK
(
n
!=
nullptr
);
return
runtime
::
detail
::
typed_packed_call_dispatcher
<
R
>
::
run
(
n
->
func
,
std
::
forward
<
Args
>
(
args
)...);
}
/*! \brief specify container node */
using
ContainerType
=
EnvFuncNode
;
};
}
// namespace tvm
#endif // TVM_API_REGISTRY_H_
#endif // TVM_API_REGISTRY_H_
include/tvm/runtime/packed_func.h
View file @
a6724b6e
...
@@ -257,6 +257,14 @@ class TypedPackedFunc<R(Args...)> {
...
@@ -257,6 +257,14 @@ class TypedPackedFunc<R(Args...)> {
const
PackedFunc
&
packed
()
const
{
const
PackedFunc
&
packed
()
const
{
return
packed_
;
return
packed_
;
}
}
/*! \return Whether the packed function is nullptr */
bool
operator
==
(
std
::
nullptr_t
null
)
const
{
return
packed_
==
nullptr
;
}
/*! \return Whether the packed function is not nullptr */
bool
operator
!=
(
std
::
nullptr_t
null
)
const
{
return
packed_
!=
nullptr
;
}
private
:
private
:
friend
class
TVMRetValue
;
friend
class
TVMRetValue
;
...
...
python/tvm/api.py
View file @
a6724b6e
...
@@ -45,6 +45,28 @@ def const(value, dtype=None):
...
@@ -45,6 +45,28 @@ def const(value, dtype=None):
return
_api_internal
.
_const
(
value
,
dtype
)
return
_api_internal
.
_const
(
value
,
dtype
)
def
get_env_func
(
name
):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Node wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return
_api_internal
.
_EnvFuncGet
(
name
)
def
convert
(
value
):
def
convert
(
value
):
"""Convert value to TVM node or function.
"""Convert value to TVM node or function.
...
...
python/tvm/container.py
View file @
a6724b6e
...
@@ -28,6 +28,20 @@ class Array(NodeBase):
...
@@ -28,6 +28,20 @@ class Array(NodeBase):
@register_node
@register_node
class
EnvFunc
(
NodeBase
):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def
__call__
(
self
,
*
args
):
return
_api_internal
.
_EnvFuncCall
(
self
,
*
args
)
@property
def
func
(
self
):
return
_api_internal
.
_EnvFuncGetPackedFunc
(
self
)
@register_node
class
Map
(
NodeBase
):
class
Map
(
NodeBase
):
"""Map container of TVM.
"""Map container of TVM.
...
...
src/api/api_test.cc
View file @
a6724b6e
...
@@ -14,6 +14,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
...
@@ -14,6 +14,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
int
axis
;
int
axis
;
std
::
string
name
;
std
::
string
name
;
Array
<
Expr
>
padding
;
Array
<
Expr
>
padding
;
TypedEnvFunc
<
int
(
int
)
>
func
;
TVM_DECLARE_ATTRS
(
TestAttrs
,
"attrs.TestAttrs"
)
{
TVM_DECLARE_ATTRS
(
TestAttrs
,
"attrs.TestAttrs"
)
{
TVM_ATTR_FIELD
(
axis
)
TVM_ATTR_FIELD
(
axis
)
...
@@ -26,6 +27,9 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
...
@@ -26,6 +27,9 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
TVM_ATTR_FIELD
(
padding
)
TVM_ATTR_FIELD
(
padding
)
.
describe
(
"padding of input"
)
.
describe
(
"padding of input"
)
.
set_default
(
Array
<
Expr
>
({
0
,
0
}));
.
set_default
(
Array
<
Expr
>
({
0
,
0
}));
TVM_ATTR_FIELD
(
func
)
.
describe
(
"some random env function"
)
.
set_default
(
TypedEnvFunc
<
int
(
int
)
>
(
nullptr
));
}
}
};
};
...
...
src/lang/api_registry.cc
0 → 100644
View file @
a6724b6e
/*!
* Copyright (c) 2018 by Contributors
* \file api_registry.cc
*/
#include <tvm/api_registry.h>
namespace
tvm
{
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
EnvFuncNode
>
([](
const
EnvFuncNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"EnvFunc("
<<
op
->
name
<<
")"
;
});
std
::
shared_ptr
<
EnvFuncNode
>
CreateEnvNode
(
const
std
::
string
&
name
)
{
auto
*
f
=
runtime
::
Registry
::
Get
(
name
);
CHECK
(
f
!=
nullptr
)
<<
"Cannot find global function
\'
"
<<
name
<<
'\''
;
std
::
shared_ptr
<
EnvFuncNode
>
n
=
std
::
make_shared
<
EnvFuncNode
>
();
n
->
func
=
*
f
;
n
->
name
=
name
;
return
n
;
}
EnvFunc
EnvFunc
::
Get
(
const
std
::
string
&
name
)
{
return
EnvFunc
(
CreateEnvNode
(
name
));
}
TVM_REGISTER_API
(
"_EnvFuncGet"
)
.
set_body_typed
<
EnvFunc
(
const
std
::
string
&
name
)
>
(
EnvFunc
::
Get
);
TVM_REGISTER_API
(
"_EnvFuncCall"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
EnvFunc
env
=
args
[
0
];
CHECK_GE
(
args
.
size
(),
1
);
env
->
func
.
CallPacked
(
TVMArgs
(
args
.
values
+
1
,
args
.
type_codes
+
1
,
args
.
size
()
-
1
),
rv
);
});
TVM_REGISTER_API
(
"_EnvFuncGetPackedFunc"
)
.
set_body_typed
<
PackedFunc
(
const
EnvFunc
&
n
)
>
([](
const
EnvFunc
&
n
)
{
return
n
->
func
;
});
TVM_REGISTER_NODE_TYPE
(
EnvFuncNode
)
.
set_creator
(
CreateEnvNode
)
.
set_global_key
([](
const
Node
*
n
)
{
return
static_cast
<
const
EnvFuncNode
*>
(
n
)
->
name
;
});
}
// namespace tvm
tests/python/unittest/test_lang_reflection.py
View file @
a6724b6e
...
@@ -56,11 +56,14 @@ def test_make_attrs():
...
@@ -56,11 +56,14 @@ def test_make_attrs():
assert
x
.
padding
[
1
]
.
value
==
4
assert
x
.
padding
[
1
]
.
value
==
4
assert
x
.
axis
==
10
assert
x
.
axis
==
10
dattr
=
tvm
.
make
.
node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
dattr
=
tvm
.
make
.
node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
assert
dattr
.
x
.
value
==
1
assert
dattr
.
x
.
value
==
1
datrr
=
tvm
.
load_json
(
tvm
.
save_json
(
dattr
))
datrr
=
tvm
.
load_json
(
tvm
.
save_json
(
dattr
))
assert
dattr
.
name
.
value
==
"xyz"
assert
dattr
.
name
.
value
==
"xyz"
def
test_make_sum
():
def
test_make_sum
():
A
=
tvm
.
placeholder
((
2
,
10
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
2
,
10
),
name
=
'A'
)
k
=
tvm
.
reduce_axis
((
0
,
10
),
"k"
)
k
=
tvm
.
reduce_axis
((
0
,
10
),
"k"
)
...
@@ -70,7 +73,33 @@ def test_make_sum():
...
@@ -70,7 +73,33 @@ def test_make_sum():
assert
B
.
op
.
body
[
0
]
.
combiner
is
not
None
assert
B
.
op
.
body
[
0
]
.
combiner
is
not
None
assert
BB
.
op
.
body
[
0
]
.
combiner
is
not
None
assert
BB
.
op
.
body
[
0
]
.
combiner
is
not
None
def
test_env_func
():
@tvm.register_func
(
"test.env_func"
)
def
test
(
x
):
return
x
+
1
f
=
tvm
.
get_global_func
(
"test.env_func"
)
x
=
tvm
.
get_env_func
(
"test.env_func"
)
assert
x
.
name
==
"test.env_func"
json_str
=
tvm
.
save_json
([
x
])
y
=
tvm
.
load_json
(
json_str
)[
0
]
assert
y
.
name
==
x
.
name
assert
y
(
1
)
==
2
assert
y
.
func
(
1
)
==
2
x
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
),
func
=
y
)
assert
x
.
name
==
"xx"
assert
x
.
padding
[
0
]
.
value
==
3
assert
x
.
padding
[
1
]
.
value
==
4
assert
x
.
axis
==
10
x
=
tvm
.
load_json
(
tvm
.
save_json
(
x
))
assert
isinstance
(
x
.
func
,
tvm
.
container
.
EnvFunc
)
assert
x
.
func
(
10
)
==
11
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_env_func
()
test_make_attrs
()
test_make_attrs
()
test_make_node
()
test_make_node
()
test_make_smap
()
test_make_smap
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment