Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
aaf7ff04
Commit
aaf7ff04
authored
Apr 01, 2018
by
alex-weaver
Committed by
Tianqi Chen
Apr 01, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Move BuildConfig context stack to C++ (#1025)
parent
7b098c9a
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
198 additions
and
39 deletions
+198
-39
include/tvm/build_module.h
+65
-1
python/tvm/build_module.py
+44
-35
python/tvm/tensor_intrin.py
+2
-2
src/codegen/build_module.cc
+86
-1
tests/python/unittest/test_pass_unroll.py
+1
-0
No files found.
include/tvm/build_module.h
View file @
aaf7ff04
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include <utility>
#include "./runtime/packed_func.h"
#include "./runtime/packed_func.h"
#include "./schedule_pass.h"
#include "./schedule_pass.h"
#include "./lowered_func.h"
#include "./lowered_func.h"
...
@@ -203,6 +204,12 @@ class BuildConfigNode : public Node {
...
@@ -203,6 +204,12 @@ class BuildConfigNode : public Node {
/*! \brief Whether to partition const loop */
/*! \brief Whether to partition const loop */
bool
partition_const_loop
=
false
;
bool
partition_const_loop
=
false
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool
dump_pass_ir
=
false
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"data_alignment"
,
&
data_alignment
);
v
->
Visit
(
"data_alignment"
,
&
data_alignment
);
v
->
Visit
(
"offset_factor"
,
&
offset_factor
);
v
->
Visit
(
"offset_factor"
,
&
offset_factor
);
...
@@ -214,13 +221,70 @@ class BuildConfigNode : public Node {
...
@@ -214,13 +221,70 @@ class BuildConfigNode : public Node {
v
->
Visit
(
"restricted_func"
,
&
restricted_func
);
v
->
Visit
(
"restricted_func"
,
&
restricted_func
);
v
->
Visit
(
"detect_global_barrier"
,
&
detect_global_barrier
);
v
->
Visit
(
"detect_global_barrier"
,
&
detect_global_barrier
);
v
->
Visit
(
"partition_const_loop"
,
&
partition_const_loop
);
v
->
Visit
(
"partition_const_loop"
,
&
partition_const_loop
);
v
->
Visit
(
"dump_pass_ir"
,
&
dump_pass_ir
);
}
}
static
constexpr
const
char
*
_type_key
=
"BuildConfig"
;
static
constexpr
const
char
*
_type_key
=
"BuildConfig"
;
TVM_DECLARE_NODE_TYPE_INFO
(
BuildConfigNode
,
Node
);
TVM_DECLARE_NODE_TYPE_INFO
(
BuildConfigNode
,
Node
);
};
};
TVM_DEFINE_NODE_REF
(
BuildConfig
,
BuildConfigNode
);
/*!
* \brief Container for build configuration options
*/
class
BuildConfig
:
public
::
tvm
::
NodeRef
{
public
:
BuildConfig
()
{}
explicit
BuildConfig
(
std
::
shared_ptr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
const
BuildConfigNode
*
operator
->
()
const
{
return
static_cast
<
const
BuildConfigNode
*>
(
node_
.
get
());
}
BuildConfigNode
*
operator
->
()
{
return
static_cast
<
BuildConfigNode
*>
(
node_
.
get
());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
EXPORT
static
void
EnterBuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
EXPORT
static
void
ExitBuildConfigScope
();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
EXPORT
static
tvm
::
BuildConfig
Current
();
using
ContainerType
=
BuildConfigNode
;
};
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct
BuildConfigContext
{
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
explicit
BuildConfigContext
(
const
tvm
::
BuildConfig
&
build_config
)
{
BuildConfig
::
EnterBuildConfigScope
(
build_config
);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~
BuildConfigContext
()
{
BuildConfig
::
ExitBuildConfigScope
();
}
};
/*!
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \brief Construct a BuildConfig containing a new BuildConfigNode
...
...
python/tvm/build_module.py
View file @
aaf7ff04
...
@@ -8,8 +8,8 @@ import warnings
...
@@ -8,8 +8,8 @@ import warnings
import
types
import
types
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.base
import
_RUNTIME_ONLY
from
.
import
api
from
.
import
api
from
.
import
_api_internal
from
.
import
tensor
from
.
import
tensor
from
.
import
schedule
from
.
import
schedule
from
.
import
expr
from
.
import
expr
...
@@ -46,7 +46,8 @@ class DumpIR(object):
...
@@ -46,7 +46,8 @@ class DumpIR(object):
retv
=
func
(
*
args
,
**
kwargs
)
retv
=
func
(
*
args
,
**
kwargs
)
if
not
isinstance
(
retv
,
(
_stmt
.
Stmt
,
container
.
LoweredFunc
,
container
.
Array
)):
if
not
isinstance
(
retv
,
(
_stmt
.
Stmt
,
container
.
LoweredFunc
,
container
.
Array
)):
return
retv
return
retv
pname
=
str
(
self
.
_pass_id
)
+
"_"
+
func
.
func_name
+
"_ir.cc"
fname
=
func
.
func_name
if
hasattr
(
func
,
'func_name'
)
else
func
.
__name__
pname
=
str
(
self
.
_pass_id
)
+
"_"
+
fname
+
"_ir.cc"
with
open
(
pname
,
"a"
)
as
f
:
with
open
(
pname
,
"a"
)
as
f
:
out
=
retv
.
body
if
isinstance
(
retv
,
container
.
LoweredFunc
)
else
retv
out
=
retv
.
body
if
isinstance
(
retv
,
container
.
LoweredFunc
)
else
retv
f
.
write
(
str
(
out
))
f
.
write
(
str
(
out
))
...
@@ -70,20 +71,20 @@ class DumpIR(object):
...
@@ -70,20 +71,20 @@ class DumpIR(object):
self
.
_recover_list
.
append
(
recover
)
self
.
_recover_list
.
append
(
recover
)
vset
[
k
]
=
self
.
decorate
(
v
)
if
isinstance
(
v
,
types
.
FunctionType
)
else
v
vset
[
k
]
=
self
.
decorate
(
v
)
if
isinstance
(
v
,
types
.
FunctionType
)
else
v
def
decorate_custompass
(
self
):
def
decorate_custompass
(
self
,
custom_pass
):
""" decorate add_lower_pass pass in BuildConfig"""
"""decorate given list of custom passes, and return decorated passes"""
cfg
=
BuildConfig
.
current
custom_pass
=
custom_pass
if
custom_pass
else
[]
self
.
_old_custom_pass
=
cfg
.
add_lower_pass
pass_list
=
[]
custom_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
for
idx
,
x
in
enumerate
(
custom_pass
):
pass_list
=
[(
x
[
0
],
self
.
decorate
(
x
[
1
]))
for
x
in
custom_pass
]
x
[
1
]
.
__name__
=
"custom{}_phase{}"
.
format
(
idx
,
x
[
0
])
BuildConfig
.
current
.
add_lower_pass
=
pass_list
pass_list
+=
[(
x
[
0
],
self
.
decorate
(
x
[
1
]))]
return
pass_list
def
enter
(
self
):
def
enter
(
self
):
"""only decorate outermost nest"""
"""only decorate outermost nest"""
if
DumpIR
.
scope_level
>
0
:
if
DumpIR
.
scope_level
>
0
:
return
return
self
.
decorate_irpass
()
self
.
decorate_irpass
()
self
.
decorate_custompass
()
self
.
_pass_id
=
0
self
.
_pass_id
=
0
DumpIR
.
scope_level
+=
1
DumpIR
.
scope_level
+=
1
...
@@ -95,7 +96,6 @@ class DumpIR(object):
...
@@ -95,7 +96,6 @@ class DumpIR(object):
for
f
in
self
.
_recover_list
:
for
f
in
self
.
_recover_list
:
f
()
f
()
schedule
.
ScheduleOps
=
self
.
_old_sgpass
schedule
.
ScheduleOps
=
self
.
_old_sgpass
BuildConfig
.
current
.
add_lower_pass
=
self
.
_old_custom_pass
DumpIR
.
scope_level
-=
1
DumpIR
.
scope_level
-=
1
@register_node
@register_node
...
@@ -113,7 +113,6 @@ class BuildConfig(NodeBase):
...
@@ -113,7 +113,6 @@ class BuildConfig(NodeBase):
is constructed. See _node_defaults for the fields.
is constructed. See _node_defaults for the fields.
"""
"""
current
=
None
_node_defaults
=
{
_node_defaults
=
{
"auto_unroll_max_step"
:
0
,
"auto_unroll_max_step"
:
0
,
"auto_unroll_max_depth"
:
8
,
"auto_unroll_max_depth"
:
8
,
...
@@ -124,8 +123,10 @@ class BuildConfig(NodeBase):
...
@@ -124,8 +123,10 @@ class BuildConfig(NodeBase):
"offset_factor"
:
0
,
"offset_factor"
:
0
,
"data_alignment"
:
-
1
,
"data_alignment"
:
-
1
,
"restricted_func"
:
True
,
"restricted_func"
:
True
,
"double_buffer_split_loop"
:
1
"double_buffer_split_loop"
:
1
,
"dump_pass_ir"
:
False
}
}
_dump_ir
=
DumpIR
()
# pylint: disable=no-member
# pylint: disable=no-member
def
__init__
(
self
,
handle
):
def
__init__
(
self
,
handle
):
...
@@ -138,24 +139,28 @@ class BuildConfig(NodeBase):
...
@@ -138,24 +139,28 @@ class BuildConfig(NodeBase):
"""
"""
super
(
BuildConfig
,
self
)
.
__init__
(
handle
)
super
(
BuildConfig
,
self
)
.
__init__
(
handle
)
self
.
handle
=
handle
self
.
handle
=
handle
self
.
_old_scope
=
None
self
.
_dump_ir
=
DumpIR
()
@property
self
.
dump_pass_ir
=
False
def
add_lower_pass
(
self
):
self
.
add_lower_pass
=
None
size
=
_api_internal
.
_BuildConfigGetAddLowerPassInfo
(
self
)
result
=
[]
for
i
in
range
(
size
):
phase
=
_api_internal
.
_BuildConfigGetAddLowerPassInfo
(
self
,
i
,
True
)
func
=
_api_internal
.
_BuildConfigGetAddLowerPassInfo
(
self
,
i
,
False
)
result
+=
[(
phase
,
func
)]
return
result
def
__enter__
(
self
):
def
__enter__
(
self
):
# pylint: disable=protected-access
# pylint: disable=protected-access
self
.
_old_scope
=
BuildConfig
.
current
_api_internal
.
_EnterBuildConfigScope
(
self
)
BuildConfig
.
current
=
self
if
self
.
dump_pass_ir
:
if
self
.
dump_pass_ir
is
True
:
BuildConfig
.
_dump_ir
.
enter
()
self
.
_dump_ir
.
enter
()
return
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
assert
self
.
_old_scope
if
self
.
dump_pass_ir
:
if
self
.
dump_pass_ir
is
True
:
BuildConfig
.
_dump_ir
.
exit
()
self
.
_dump_ir
.
exit
()
_api_internal
.
_ExitBuildConfigScope
()
BuildConfig
.
current
=
self
.
_old_scope
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
,
value
):
if
name
in
BuildConfig
.
_node_defaults
:
if
name
in
BuildConfig
.
_node_defaults
:
...
@@ -163,6 +168,9 @@ class BuildConfig(NodeBase):
...
@@ -163,6 +168,9 @@ class BuildConfig(NodeBase):
"'
%
s' object cannot set attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
"'
%
s' object cannot set attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
super
(
BuildConfig
,
self
)
.
__setattr__
(
name
,
value
)
return
super
(
BuildConfig
,
self
)
.
__setattr__
(
name
,
value
)
def
current_build_config
():
return
_api_internal
.
_GetCurrentBuildConfig
()
def
build_config
(
**
kwargs
):
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
"""Configure the build behavior by setting config variables.
...
@@ -221,14 +229,13 @@ def build_config(**kwargs):
...
@@ -221,14 +229,13 @@ def build_config(**kwargs):
for
k
,
v
in
BuildConfig
.
_node_defaults
.
items
()}
for
k
,
v
in
BuildConfig
.
_node_defaults
.
items
()}
config
=
make
.
node
(
"BuildConfig"
,
**
node_args
)
config
=
make
.
node
(
"BuildConfig"
,
**
node_args
)
for
k
in
kwargs
:
if
"add_lower_pass"
in
kwargs
:
if
not
k
in
node_args
:
add_lower_pass_args
=
[]
setattr
(
config
,
k
,
kwargs
[
k
])
for
x
in
kwargs
[
"add_lower_pass"
]:
return
config
add_lower_pass_args
+=
[
x
[
0
],
x
[
1
]]
_api_internal
.
_BuildConfigSetAddLowerPass
(
config
,
*
add_lower_pass_args
)
if
not
_RUNTIME_ONLY
:
return
config
# BuildConfig is not available in tvm_runtime
BuildConfig
.
current
=
build_config
()
def
get_binds
(
args
,
binds
=
None
):
def
get_binds
(
args
,
binds
=
None
):
"""Internal function to get binds and arg_list given arguments.
"""Internal function to get binds and arg_list given arguments.
...
@@ -252,7 +259,7 @@ def get_binds(args, binds=None):
...
@@ -252,7 +259,7 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
The list of symbolic buffers of arguments.
"""
"""
binds
=
{}
if
binds
is
None
else
binds
.
copy
()
binds
=
{}
if
binds
is
None
else
binds
.
copy
()
cfg
=
BuildConfig
.
current
cfg
=
current_build_config
()
arg_list
=
[]
arg_list
=
[]
for
x
in
args
:
for
x
in
args
:
if
isinstance
(
x
,
tensor
.
Tensor
):
if
isinstance
(
x
,
tensor
.
Tensor
):
...
@@ -309,8 +316,10 @@ def lower(sch,
...
@@ -309,8 +316,10 @@ def lower(sch,
Then the Stmt before make api is returned.
Then the Stmt before make api is returned.
"""
"""
binds
,
arg_list
=
get_binds
(
args
,
binds
)
binds
,
arg_list
=
get_binds
(
args
,
binds
)
cfg
=
BuildConfig
.
current
cfg
=
current_build_config
()
add_lower_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
add_lower_pass
=
cfg
.
add_lower_pass
if
cfg
.
add_lower_pass
else
[]
if
cfg
.
dump_pass_ir
:
add_lower_pass
=
BuildConfig
.
_dump_ir
.
decorate_custompass
(
add_lower_pass
)
lower_phase0
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
0
]
lower_phase0
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
0
]
lower_phase1
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
1
]
lower_phase1
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
1
]
lower_phase2
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
2
]
lower_phase2
=
[
x
[
1
]
for
x
in
add_lower_pass
if
x
[
0
]
==
2
]
...
@@ -434,7 +443,7 @@ def build(sch,
...
@@ -434,7 +443,7 @@ def build(sch,
"Direct host side access to device memory is detected in
%
s. "
"Direct host side access to device memory is detected in
%
s. "
"Did you forget to bind?"
%
func
.
name
)
"Did you forget to bind?"
%
func
.
name
)
if
func
.
func_type
==
container
.
LoweredFunc
.
MixedFunc
:
if
func
.
func_type
==
container
.
LoweredFunc
.
MixedFunc
:
if
BuildConfig
.
current
.
detect_global_barrier
:
if
current_build_config
()
.
detect_global_barrier
:
func
=
ir_pass
.
ThreadSync
(
func
,
"global"
)
func
=
ir_pass
.
ThreadSync
(
func
,
"global"
)
func
=
ir_pass
.
ThreadSync
(
func
,
"shared"
)
func
=
ir_pass
.
ThreadSync
(
func
,
"shared"
)
warp_size
=
target
.
thread_warp_size
warp_size
=
target
.
thread_warp_size
...
...
python/tvm/tensor_intrin.py
View file @
aaf7ff04
...
@@ -6,7 +6,7 @@ from . import expr as _expr
...
@@ -6,7 +6,7 @@ from . import expr as _expr
from
.
import
stmt
as
_stmt
from
.
import
stmt
as
_stmt
from
.
import
make
as
_make
from
.
import
make
as
_make
from
.
import
tensor
as
_tensor
from
.
import
tensor
as
_tensor
from
.build_module
import
BuildC
onfig
from
.build_module
import
current_build_c
onfig
from
._ffi.node
import
NodeBase
,
register_node
from
._ffi.node
import
NodeBase
,
register_node
@register_node
@register_node
...
@@ -74,7 +74,7 @@ def decl_tensor_intrin(op,
...
@@ -74,7 +74,7 @@ def decl_tensor_intrin(op,
if
not
isinstance
(
t
.
op
,
_tensor
.
PlaceholderOp
):
if
not
isinstance
(
t
.
op
,
_tensor
.
PlaceholderOp
):
raise
ValueError
(
"Donot yet support composition op"
)
raise
ValueError
(
"Donot yet support composition op"
)
cfg
=
BuildConfig
.
current
cfg
=
current_build_config
()
for
t
in
tensors
:
for
t
in
tensors
:
buf
=
(
binds
[
t
]
if
t
in
binds
else
buf
=
(
binds
[
t
]
if
t
in
binds
else
_api
.
decl_buffer
(
t
.
shape
,
t
.
dtype
,
t
.
op
.
name
,
_api
.
decl_buffer
(
t
.
shape
,
t
.
dtype
,
t
.
op
.
name
,
...
...
src/codegen/build_module.cc
View file @
aaf7ff04
...
@@ -468,6 +468,41 @@ BuildConfig build_config() {
...
@@ -468,6 +468,41 @@ BuildConfig build_config() {
return
BuildConfig
(
std
::
make_shared
<
BuildConfigNode
>
());
return
BuildConfig
(
std
::
make_shared
<
BuildConfigNode
>
());
}
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct
TVMBuildConfigThreadLocalEntry
{
/*! \brief The default build config if the stack is empty */
tvm
::
BuildConfig
default_config
;
/*! \brief The current build config context */
std
::
stack
<
tvm
::
BuildConfig
>
context_stack
;
TVMBuildConfigThreadLocalEntry
()
:
default_config
(
build_config
())
{
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMBuildConfigThreadLocalEntry
>
TVMBuildConfigThreadLocalStore
;
void
BuildConfig
::
EnterBuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
)
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
build_config
);
}
void
BuildConfig
::
ExitBuildConfigScope
()
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
entry
->
context_stack
.
pop
();
}
tvm
::
BuildConfig
BuildConfig
::
Current
()
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
if
(
entry
->
context_stack
.
size
()
>
0
)
{
return
entry
->
context_stack
.
top
();
}
return
entry
->
default_config
;
}
TVM_REGISTER_NODE_TYPE
(
BuildConfigNode
);
TVM_REGISTER_NODE_TYPE
(
BuildConfigNode
);
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
...
@@ -482,7 +517,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...
@@ -482,7 +517,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p
->
stream
<<
"unroll_explicit="
<<
op
->
unroll_explicit
<<
", "
;
p
->
stream
<<
"unroll_explicit="
<<
op
->
unroll_explicit
<<
", "
;
p
->
stream
<<
"restricted_func="
<<
op
->
restricted_func
<<
", "
;
p
->
stream
<<
"restricted_func="
<<
op
->
restricted_func
<<
", "
;
p
->
stream
<<
"detect_global_barrier="
<<
op
->
detect_global_barrier
<<
", "
;
p
->
stream
<<
"detect_global_barrier="
<<
op
->
detect_global_barrier
<<
", "
;
p
->
stream
<<
"partition_const_loop="
<<
op
->
partition_const_loop
;
p
->
stream
<<
"partition_const_loop="
<<
op
->
partition_const_loop
<<
", "
;
p
->
stream
<<
"dump_pass_ir="
<<
op
->
dump_pass_ir
;
p
->
stream
<<
")"
;
p
->
stream
<<
")"
;
});
});
...
@@ -571,6 +607,55 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
...
@@ -571,6 +607,55 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
func
.
CallPacked
(
args
,
ret
);
func
.
CallPacked
(
args
,
ret
);
}
}
TVM_REGISTER_API
(
"_GetCurrentBuildConfig"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
BuildConfig
::
Current
();
});
TVM_REGISTER_API
(
"_EnterBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
target
=
args
[
0
];
BuildConfig
::
EnterBuildConfigScope
(
target
);
});
TVM_REGISTER_API
(
"_ExitBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
::
ExitBuildConfigScope
();
});
TVM_REGISTER_API
(
"_BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
cfg
=
args
[
0
];
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
add_lower_pass
.
push_back
(
std
::
make_pair
(
args
[
i
].
operator
int
(),
args
[
i
+
1
].
operator
tvm
::
runtime
::
PackedFunc
()));
}
cfg
->
add_lower_pass
=
add_lower_pass
;
});
TVM_REGISTER_API
(
"_BuildConfigGetAddLowerPassInfo"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
// * Phase index of pass if args are (config, index, true)
// * Function of pass if args are (config, index, false)
BuildConfig
cfg
=
args
[
0
];
if
(
args
.
num_args
==
1
)
{
*
ret
=
static_cast
<
int64_t
>
(
cfg
->
add_lower_pass
.
size
());
}
else
{
int
index
=
args
[
1
];
bool
get_phase
=
args
[
2
];
auto
item
=
cfg
->
add_lower_pass
[
index
];
if
(
get_phase
)
{
*
ret
=
item
.
first
;
}
else
{
*
ret
=
item
.
second
;
}
}
});
TVM_REGISTER_API
(
"_GenericFuncCreate"
)
TVM_REGISTER_API
(
"_GenericFuncCreate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
tests/python/unittest/test_pass_unroll.py
View file @
aaf7ff04
...
@@ -38,6 +38,7 @@ if __name__ == "__main__":
...
@@ -38,6 +38,7 @@ if __name__ == "__main__":
file_list
=
os
.
listdir
(
'./'
)
file_list
=
os
.
listdir
(
'./'
)
cc_file
=
end_with
(
'.cc'
)
cc_file
=
end_with
(
'.cc'
)
cc_file
=
filter
(
cc_file
,
file_list
)
cc_file
=
filter
(
cc_file
,
file_list
)
cc_file
=
[
f
for
f
in
cc_file
]
assert
len
(
cc_file
)
==
3
assert
len
(
cc_file
)
==
3
for
i
in
cc_file
:
for
i
in
cc_file
:
os
.
remove
(
i
)
os
.
remove
(
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment