wenyuanbo / tic / Commits / 41768cf9

Commit 41768cf9
authored Aug 05, 2017 by Tianqi Chen, committed by GitHub on Aug 05, 2017
[SCHEDULE][RUNTIME] Introduce pragma for additional extension hint, threadpool runtime. (#299)
parent fd96d285
Showing 20 changed files with 718 additions and 153 deletions (+718 -153)
apps/ios_rpc/tests/ios_rpc_test.py              +28  -6
apps/ios_rpc/tvmrpc/TVMRuntime.mm               +1   -0
docs/api/python/dev.rst                         +1   -1
include/tvm/ir.h                                +2   -0
include/tvm/runtime/c_backend_api.h             +38  -10
include/tvm/schedule.h                          +14  -0
python/tvm/module.py                            +1   -1
python/tvm/schedule.py                          +44  -3
src/api/api_lang.cc                             +6   -0
src/codegen/llvm/codegen_llvm.cc                +123 -38
src/codegen/llvm/codegen_llvm.h                 +25  -8
src/op/op_util.cc                               +4   -0
src/runtime/c_runtime_api.cc                    +0   -60
src/runtime/module_util.h                       +14  -23
src/runtime/thread_pool.cc                      +346 -0
src/schedule/schedule_lang.cc                   +13  -0
tests/python/unittest/test_codegen_llvm.py      +7   -2
tests/python/unittest/test_codegen_vm_basic.py  +21  -1
tests/python/unittest/test_lang_schedule.py     +17  -0
web/web_runtime.cc                              +13  -0
apps/ios_rpc/tests/ios_rpc_test.py
@@ -32,23 +32,35 @@ def test_rpc_module():
     n = tvm.convert(1024)
     A = tvm.placeholder((n,), name='A')
     B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    temp = util.tempdir()
     s = tvm.create_schedule(B.op)
     xo, xi = s[B].split(B.op.axis[0], factor=64)
     s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
     s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
-    temp = util.tempdir()
     # Build the dynamic lib.
     # If we don't want to do metal and only use cpu, just set target to be target
     f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd")
-    path_dso = temp.relpath("dev_lib.dylib")
-    f.export_library(path_dso, xcode.create_dylib,
+    path_dso1 = temp.relpath("dev_lib.dylib")
+    f.export_library(path_dso1, xcode.create_dylib,
                      arch=arch, sdk=sdk)
-    xcode.codesign(path_dso)
+    xcode.codesign(path_dso1)
+
+    s = tvm.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=64)
+    s[B].parallel(xi)
+    s[B].pragma(xo, "parallel_launch_point")
+    s[B].pragma(xi, "parallel_barrier_when_finish")
+    f = tvm.build(s, [A, B], target, name="myadd_cpu")
+    path_dso2 = temp.relpath("cpu_lib.dylib")
+    f.export_library(path_dso2, xcode.create_dylib,
+                     arch=arch, sdk=sdk)
+    xcode.codesign(path_dso2)
+
     # Start RPC test server that contains the compiled library.
     server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
                                   destination=destination,
-                                  libs=[path_dso])
+                                  libs=[path_dso1, path_dso2],
+                                  options=["-quiet"])
     # connect to the proxy
     remote = rpc.connect(proxy_host, proxy_port, key=key)
     ctx = remote.metal(0)
@@ -60,5 +72,15 @@ def test_rpc_module():
     cost = time_f(a, b).mean
     print('%g secs/op' % cost)
     np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+    # CPU
+    ctx = remote.cpu(0)
+    f2 = remote.load_module("cpu_lib.dylib")
+    a_np = np.random.uniform(size=1024).astype(A.dtype)
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+    time_f = f2.time_evaluator(f1.entry_name, ctx, number=10)
+    cost = time_f(a, b).mean
+    print('%g secs/op' % cost)
+    np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

 test_rpc_module()
apps/ios_rpc/tvmrpc/TVMRuntime.mm
@@ -7,6 +7,7 @@
 #include "../../src/runtime/c_runtime_api.cc"
 #include "../../src/runtime/cpu_device_api.cc"
 #include "../../src/runtime/workspace_pool.cc"
+#include "../../src/runtime/thread_pool.cc"
 #include "../../src/runtime/module_util.cc"
 #include "../../src/runtime/system_lib_module.cc"
 #include "../../src/runtime/module.cc"
docs/api/python/dev.rst
@@ -45,7 +45,7 @@ tvm.ir_pass
    tvm.ir_pass.StorageFlatten
    tvm.ir_pass.VectorizeLoop
    tvm.ir_pass.UnrollLoop
-   tvm.ir_pass.StorageSync
+   tvm.ir_pass.ThreadSync
    tvm.ir_pass.StorageRewrite
    tvm.ir_pass.MakeAPI
    tvm.ir_pass.SplitHostDevice
include/tvm/ir.h
@@ -166,6 +166,8 @@ constexpr const char* device_context_type = "device_context_type";
 constexpr const char* loop_scope = "loop_scope";
 /*! \brief Mark of reduce scope */
 constexpr const char* reduce_scope = "reduce_scope";
+/*! \brief Mark region is guarded by the pragma */
+constexpr const char* pragma_scope = "pragma_scope";
 /*!
  * \brief Mark of prefetch scope, value=offset,
  *  run prefetch of Tensor on the current loop scope
include/tvm/runtime/c_backend_api.h
@@ -66,21 +66,49 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
 TVM_DLL int TVMBackendFreeWorkspace(int device_type,
                                     int device_id,
                                     void* ptr);
+
+/*!
+ * \brief Environment for TVM parallel task.
+ */
+typedef struct {
+  /*!
+   * \brief Auxiliary used for synchronization
+   */
+  void* sync_handle;
+  /*! \brief total amount of task */
+  int32_t num_task;
+} TVMParallelGroupEnv;
+
+/*!
+ * \brief The callback function to execute a parallel lambda
+ * \param task_id the task id of the function.
+ * \param penv The parallel environment backs the execution.
+ * \param cdata The supporting closure data.
+ */
+typedef int (*FTVMParallelLambda)(
+    int task_id, TVMParallelGroupEnv* penv, void* cdata);
+
 /*!
- * \brief Backend function for running parallel for loop.
+ * \brief Backend function for running parallel jobs.
  *
- * \param begin The start of iteration.
- * \param end The end of iteration.
- * \param lambda The lambda function to be executed.
- * \param env The environment of lambda function.
+ * \param flambda The parallel function to be launched.
+ * \param cdata The closure data.
+ * \param num_task Number of tasks to launch, can be 0, means launch
+ *           with all available threads.
  *
  * \return 0 when no error is thrown, -1 when failure happens
  */
-TVM_DLL int TVMBackendParallelFor(
-    int64_t begin,
-    int64_t end,
-    int (*lambda)(int64_t begin, int64_t end, void* env),
-    void* env);
+TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
+                                     void* cdata,
+                                     int num_task);
+
+/*!
+ * \brief BSP barrrier between parallel threads
+ * \param task_id the task id of the function.
+ * \param penv The parallel environment backs the execution.
+ * \return 0 when no error is thrown, -1 when failure happens
+ */
+TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
+
 #ifdef __cplusplus
 }  // TVM_EXTERN_C
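The new backend entry points declared above are plain C symbols, so they can be driven directly for experimentation. Below is a minimal sketch, not part of the commit, of calling TVMBackendParallelLaunch through Python ctypes; the library name libtvm.so and the toy callback are assumptions for illustration, not anything defined by this change.

import ctypes

class TVMParallelGroupEnv(ctypes.Structure):
    # mirrors the struct declared above: sync handle + total number of tasks
    _fields_ = [("sync_handle", ctypes.c_void_p),
                ("num_task", ctypes.c_int32)]

# int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata)
FTVMParallelLambda = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_int, ctypes.POINTER(TVMParallelGroupEnv), ctypes.c_void_p)

lib = ctypes.CDLL("libtvm.so")  # assumption: a TVM runtime built from this revision
lib.TVMBackendParallelLaunch.restype = ctypes.c_int
lib.TVMBackendParallelLaunch.argtypes = [FTVMParallelLambda, ctypes.c_void_p, ctypes.c_int]

seen = []

@FTVMParallelLambda
def hello_task(task_id, penv, cdata):
    # each worker runs this once; returning non-zero reports an error to the launcher
    seen.append((task_id, penv.contents.num_task))
    return 0

# num_task = 0 asks the pool to use one task per available worker thread
ret = lib.TVMBackendParallelLaunch(hello_task, None, 0)
print("launch returned", ret, "tasks:", sorted(seen))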
include/tvm/schedule.h
@@ -181,6 +181,15 @@ class Stage : public NodeRef {
    */
   Stage& parallel(IterVar var);   // NOLINT(*)
+  /*!
+   * \brief Annotate the iteration with pragma
+   *
+   * \param var The axis to be parallelized.
+   * \param pragma_type The pragma type.
+   *
+   * \return reference to self.
+   */
+  Stage& pragma(IterVar var, const std::string& pragma_type);   // NOLINT(*)
   /*!
    * \brief Fetch data in advance.
    * \param domain the tensor to be prefetched
    * \param var the iteration point at which to apply prefetching
@@ -487,6 +496,10 @@ class IterVarAttrNode : public Node {
    *  when the axis is marked as Tensorized
    */
   TensorIntrin tensor_intrin;
+  /*!
+   * \brief Additional pragmas, array of StringImm
+   */
+  Array<Expr> pragmas;

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("iter_type", &iter_type);
@@ -494,6 +507,7 @@ class IterVarAttrNode : public Node {
     v->Visit("prefetch_data", &prefetch_data);
     v->Visit("prefetch_offset", &prefetch_offset);
     v->Visit("tensor_intrin", &tensor_intrin);
+    v->Visit("pragmas", &pragmas);
   }

   static constexpr const char* _type_key = "IterVarAttr";
python/tvm/module.py
@@ -78,7 +78,7 @@ class Module(ModuleBase):
         file_name : str
             The name of the shared library.

-        fcompile : function(target, file_list, **kwargs), optional
+        fcompile : function(target, file_list, kwargs), optional
             Compilation function to use create dynamic library.

         kwargs : dict, optiona;
python/tvm/schedule.py
@@ -26,7 +26,7 @@ class Buffer(NodeBase):
     WRITE = 2

     def access_ptr(self, access_mask, ptr_type="handle"):
-        """Get an access pointer to the head of buffer
+        """Get an access pointer to the head of buffer.
         This is the recommended method to get buffer data
         ptress when interacting with external functions.
@@ -37,7 +37,6 @@ class Buffer(NodeBase):
             The access pattern MASK. Indicate whether the
             access will read or write to the data content.
-
         ptr_type : str, optional
             The data type of the result pointer. Do not specify
             unless we want to cast pointer to specific type.
@@ -45,8 +44,8 @@ class Buffer(NodeBase):
         Examples
         --------
         .. code-block:: python

            import tvm.schedule.Buffer
            # Get access ptr for read
            buffer.access_ptr("r")
            # Get access ptr for read/write with bitmask
@@ -465,6 +464,48 @@ class Stage(NodeBase):
         """
         _api_internal._StageParallel(self, var)

+    def pragma(self, var, pragma_type):
+        """Annotate the iteration with pragma
+
+        This will translate to a pragma_scope surrounding
+        the corresponding loop generated.
+        Useful to support experimental features and extensions.
+
+        Parameters
+        ----------
+        var : IterVar
+            The iteration to be anotated
+
+        pragma_type : str
+             The pragma string to be annotated
+
+        Note
+        ----
+        Most pragmas are advanced/experimental features
+        and may subject to change. List of supported pragmas:
+
+        - **parallel_launch_point**
+
+          Specify to launch parallel threads outside the
+          specified iteration loop. By default the threads
+          launch at the point of parallel construct.
+          This pragma moves the launching point to even outer scope.
+          The threads are launched once and reused across multiple
+          parallel constructs as BSP style program.
+
+        - **parallel_barrier_when_finish**
+
+          Insert a synchronization barrier between working threads
+          after the specified loop iteration finishes.
+
+        - **parallel_stride_pattern**
+
+          Hint parallel loop to execute in strided pattern.
+          :code:`for (int i = task_id; i < end; i += num_task)`
+        """
+        _api_internal._StagePragma(self, var, pragma_type)
+
     def prefetch(self, tensor, var, offset):
         """Prefetch the specified variable
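As a usage note, the three pragmas documented above compose with Stage.parallel. The sketch below is distilled from the tests added in this commit; it assumes an LLVM-enabled build of this revision, and the name myadd_cpu is only illustrative.

import tvm
import numpy as np

n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].parallel(xi)
# Launch the worker group outside the xo loop and reuse it BSP-style,
# then synchronize the workers once the xi loop finishes.
s[B].pragma(xo, "parallel_launch_point")
s[B].pragma(xi, "parallel_barrier_when_finish")

f = tvm.build(s, [A, B], "llvm", name="myadd_cpu")
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1.0)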
src/api/api_lang.cc
@@ -364,6 +364,12 @@ TVM_REGISTER_API("_StageParallel")
     .parallel(args[1]);
   });

+TVM_REGISTER_API("_StagePragma")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Stage()
+        .pragma(args[1], args[2]);
+  });
+
 TVM_REGISTER_API("_StagePrefetch")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     args[0].operator Stage()
src/codegen/llvm/codegen_llvm.cc
@@ -63,8 +63,14 @@ void CodeGenLLVM::Init(const std::string& module_name,
       t_tvm_shape_index_->getPointerTo(),
       t_int64_});
   t_tvm_value_ = llvm::StructType::create({t_float64_});
-  ftype_tvm_par_for_lambda_ = llvm::FunctionType::get(
-      t_int_,
-      {t_int64_, t_int64_, t_void_p_}, false);
+  t_tvm_parallel_group_env_ = llvm::StructType::create({
+      t_int32_->getPointerTo(), t_int32_});
+  ftype_tvm_parallel_lambda_ = llvm::FunctionType::get(
+      t_int_,
+      {t_int_,
+       t_tvm_parallel_group_env_->getPointerTo(),
+       t_void_p_}, false);
   md_builder_.reset(new llvm::MDBuilder(*ctx));
   md_very_likely_branch_ =
       md_builder_->createBranchWeights(1 << 30, 0);
@@ -90,9 +96,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
        t_tvm_func_handle_->getPointerTo()}, false);
   ftype_tvm_api_set_last_error_ = llvm::FunctionType::get(
       t_void_, {t_char_->getPointerTo()}, false);
-  ftype_tvm_parallel_for_ =
-      llvm::FunctionType::get(t_int_, {
-          t_int64_, t_int64_,
-          ftype_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
-        , false);
+  ftype_tvm_parallel_launch_ =
+      llvm::FunctionType::get(t_int_,
+          {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}
+        , false);
+  ftype_tvm_parallel_barrier_ =
+      llvm::FunctionType::get(t_int_,
+          {t_int_, t_tvm_parallel_group_env_->getPointerTo()}
+        , false);
   // initialize TVM runtime API
   if (system_lib) {
@@ -113,9 +123,12 @@ void CodeGenLLVM::Init(const std::string& module_name,
     f_tvm_api_set_last_error_ = llvm::Function::Create(
         ftype_tvm_api_set_last_error_,
         llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
-    f_tvm_parallel_for_ = llvm::Function::Create(
-        ftype_tvm_parallel_for_,
-        llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
+    f_tvm_parallel_launch_ = llvm::Function::Create(
+        ftype_tvm_parallel_launch_,
+        llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
+    f_tvm_parallel_barrier_ = llvm::Function::Create(
+        ftype_tvm_parallel_barrier_,
+        llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
   }
   this->InitTarget(tm);
   // initialize builder
@@ -179,8 +192,10 @@ void CodeGenLLVM::InitGlobalContext(bool dynamic_lookup) {
       ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv");
   gv_tvm_api_set_last_error_ = InitContextPtr(
       ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
-  gv_tvm_parallel_for_ = InitContextPtr(
-      ftype_tvm_parallel_for_->getPointerTo(), "__TVMBackendParallelFor");
+  gv_tvm_parallel_launch_ = InitContextPtr(
+      ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch");
+  gv_tvm_parallel_barrier_ = InitContextPtr(
+      ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier");
   // Mark as context functions
   gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
   gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
@@ -702,9 +717,14 @@ llvm::Value* CodeGenLLVM::RuntimeTVMAPISetLastError() {
   if (f_tvm_api_set_last_error_ != nullptr) return f_tvm_api_set_last_error_;
   return GetContextPtr(gv_tvm_api_set_last_error_);
 }
-llvm::Value* CodeGenLLVM::RuntimeTVMParallelFor() {
-  if (f_tvm_parallel_for_ != nullptr) return f_tvm_parallel_for_;
-  return GetContextPtr(gv_tvm_parallel_for_);
+llvm::Value* CodeGenLLVM::RuntimeTVMParallelLaunch() {
+  if (f_tvm_parallel_launch_ != nullptr) return f_tvm_parallel_launch_;
+  return GetContextPtr(gv_tvm_parallel_launch_);
+}
+llvm::Value* CodeGenLLVM::RuntimeTVMParallelBarrier() {
+  if (f_tvm_parallel_barrier_ != nullptr) return f_tvm_parallel_barrier_;
+  return GetContextPtr(gv_tvm_parallel_barrier_);
 }
 llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
@@ -782,15 +802,9 @@ void CodeGenLLVM::CreateComputeScope(const AttrStmt* op) {
   builder_->SetInsertPoint(compute_call_end);
 }

-void CodeGenLLVM::CreateParallelFor(const For* op) {
+void CodeGenLLVM::CreateParallelLaunch(const Stmt& body, int num_task) {
   using llvm::BasicBlock;
-  llvm::Value* min = MakeValue(op->min);
-  llvm::Value* extent = MakeValue(op->extent);
-  min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
-  extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
-  // fields to be packed into closure.
-  Var loop_var(op->loop_var.node_);
-  Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
+  Array<Var> vfields = ir::UndefinedVars(body, {});
   std::vector<llvm::Type*> fields;
   for (Var v : vfields) {
     auto it = var_map_.find(v.get());
@@ -800,9 +814,9 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
   // closure data
   llvm::StructType* tcdata = llvm::StructType::create(fields);
   llvm::Function* f = llvm::Function::Create(
-      ftype_tvm_par_for_lambda_,
+      ftype_tvm_parallel_lambda_,
       llvm::Function::PrivateLinkage,
-      "__tvm_par_for_lambda", module_.get());
+      "__tvm_parallel_lambda", module_.get());
   // allocate and setup the closure, call the closure.
   llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
   llvm::Value* zero = ConstInt32(0);
@@ -812,19 +826,17 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
         var_map_.at(vfields[i].get()),
         builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
   }
-  BasicBlock* par_for_end = CheckCallSuccess(
+  BasicBlock* par_launch_end = CheckCallSuccess(
       builder_->CreateCall(
-          RuntimeTVMParallelFor(),
-          {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
+          RuntimeTVMParallelLaunch(),
+          {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
   // Setup the closure function.
   BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
   builder_->SetInsertPoint(lambda_entry);
   auto it = f->arg_begin();
-  llvm::Value* begin = &(*it++);
-  llvm::Value* end = &(*it++);
+  llvm::Value* task_id = &(*it++);
+  llvm::Value* penv = &(*it++);
   cdata = &(*it++);
-  begin = CreateCast(Int(64), op->loop_var.type(), begin);
-  end = CreateCast(Int(64), op->loop_var.type(), end);
   cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
   // setup new variable map, swap it with current var context.
   std::unordered_map<const Variable*, llvm::Value*> new_vmap;
@@ -833,17 +845,32 @@ void CodeGenLLVM::CreateParallelFor(const For* op) {
         builder_->CreateLoad(builder_->CreateInBoundsGEP(
             cdata, {zero, ConstInt32(i)}));
   }
+  // setup parallel env
+  ParallelEnv par_env;
+  par_env.task_id = Var("task_id", Int(32));
+  par_env.num_task = Var("num_task", Int(32));
+  new_vmap[par_env.task_id.get()] = task_id;
+  new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
+      builder_->CreateInBoundsGEP(penv, {zero, ConstInt32(1)}));
+  par_env.penv = penv;
   std::swap(function_, f);
-  std::swap(new_vmap, var_map_);
-  CreateSerialFor(begin, end, op->loop_var, op->body);
+  std::swap(var_map_, new_vmap);
+  std::swap(parallel_env_, par_env);
+  this->VisitStmt(body);
   builder_->CreateRet(ConstInt32(0));
   // swap the var map back, now we are back on track.
-  std::swap(new_vmap, var_map_);
+  std::swap(var_map_, new_vmap);
+  std::swap(parallel_env_, par_env);
   std::swap(function_, f);
-  builder_->SetInsertPoint(par_for_end);
+  CHECK(par_env.hit_parallel_loop)
+      << "Cannot find parallel loop within parallel launch";
+  builder_->SetInsertPoint(par_launch_end);
 }

-void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
+void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
+                                  llvm::Value* end,
+                                  llvm::Value* stride,
                                   const VarExpr& loop_var, const Stmt& body) {
   using llvm::BasicBlock;
   Type t = loop_var.type();
@@ -864,7 +891,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
   builder_->SetInsertPoint(for_body);
   var_map_[loop_var.get()] = index;
   this->VisitStmt(body);
-  llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
+  llvm::Value* next_index = CreateAdd(t, index, stride);
   index->addIncoming(next_index, builder_->GetInsertBlock());
   builder_->CreateBr(for_head);
   // end of for
@@ -1481,10 +1508,45 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
 void CodeGenLLVM::VisitStmt_(const For* op) {
   CHECK(is_zero(op->min));
   if (op->for_type == ForType::Serial) {
-    CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
-                    op->loop_var, op->body);
+    CreateSerialFor(ConstInt32(0),
+                    MakeValue(op->extent),
+                    ConstInt32(1),
+                    op->loop_var, op->body);
   } else if (op->for_type == ForType::Parallel) {
-    CreateParallelFor(op);
+    if (parallel_env_.penv == nullptr) {
+      CreateParallelLaunch(
+          For::make(op->loop_var, op->min, op->extent,
+                    op->for_type, op->device_api, op->body), 0);
+    } else {
+      // already in parallel env.
+      CHECK(parallel_env_.task_id.defined());
+      CHECK(parallel_env_.num_task.defined());
+      CHECK(parallel_env_.penv != nullptr);
+      Type t = op->extent.type();
+      Expr num_task = cast(t, parallel_env_.num_task);
+      Expr task_id = cast(t, parallel_env_.task_id);
+      CHECK(!parallel_env_.hit_parallel_loop)
+          << "Nested parallel loop is not supported by threadpool, try fuse them instead";
+      parallel_env_.hit_parallel_loop = true;
+      if (parallel_env_.stride_pattern) {
+        CreateSerialFor(MakeValue(task_id),
+                        MakeValue(op->extent),
+                        MakeValue(num_task),
+                        op->loop_var, op->body);
+      } else {
+        Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
+        Expr begin = Min::make(task_id * step, op->extent);
+        Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
+        CreateSerialFor(MakeValue(begin),
+                        MakeValue(end),
+                        ConstInt32(1),
+                        op->loop_var, op->body);
+      }
+    }
   } else {
     LOG(FATAL) << "cannot handle for type " << op->for_type;
   }
@@ -1566,6 +1628,29 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
     this->VisitStmt(op->body);
   } else if (op->attr_key == ir::attr::compute_scope) {
     this->CreateComputeScope(op);
+  } else if (op->attr_key == ir::attr::pragma_scope) {
+    const std::string& pname = op->value.as<StringImm>()->value;
+    if (pname == "parallel_stride_pattern") {
+      CHECK(parallel_env_.penv != nullptr)
+          << "Pragma parallel_stride_pattern only valid in parallel launch";
+      parallel_env_.stride_pattern = true;
+      this->VisitStmt(op->body);
+    } else if (pname == "parallel_launch_point") {
+      CreateParallelLaunch(op->body, 0);
+    } else if (pname == "parallel_barrier_when_finish") {
+      CHECK(parallel_env_.penv != nullptr)
+          << "Cannot run barrier without parallel environment";
+      CHECK(!parallel_env_.hit_parallel_loop)
+          << "Cannot not place within parallel loop as the workload may differ, "
+          << " place it between parallel and parallel_launch_point";
+      this->VisitStmt(op->body);
+      builder_->CreateCall(
+          RuntimeTVMParallelBarrier(),
+          {MakeValue(parallel_env_.task_id), parallel_env_.penv});
+    } else {
+      LOG(WARNING) << "Unknown pragma " << pname;
+      this->VisitStmt(op->body);
+    }
   } else {
     this->VisitStmt(op->body);
   }
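For the parallel lowering above, each task either receives one contiguous chunk of size ceil(extent / num_task), clamped with Min, or, under parallel_stride_pattern, iterates with stride num_task starting at task_id. A plain-Python illustration of the resulting index ranges follows; extent=10 and num_task=4 are example values only, no TVM required.

# Per-task iteration ranges implied by the lowering above.
extent, num_task = 10, 4
step = (extent + num_task - 1) // num_task       # ceil(extent / num_task) = 3

for task_id in range(num_task):
    # default chunked partition: [begin, end) clamped to the extent
    begin = min(task_id * step, extent)
    end = min((task_id + 1) * step, extent)
    chunked = list(range(begin, end))
    # parallel_stride_pattern: for (i = task_id; i < extent; i += num_task)
    strided = list(range(task_id, extent, num_task))
    print(task_id, chunked, strided)

# task 0 -> chunked [0, 1, 2], strided [0, 4, 8]
# task 3 -> chunked [9],       strided [3, 7]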
src/codegen/llvm/codegen_llvm.h
@@ -189,11 +189,13 @@ class CodeGenLLVM :
   llvm::StructType* t_tvm_type_{nullptr};
   llvm::StructType* t_tvm_array_{nullptr};
   llvm::StructType* t_tvm_value_{nullptr};
-  llvm::FunctionType* ftype_tvm_par_for_lambda_{nullptr};
+  llvm::StructType* t_tvm_parallel_group_env_{nullptr};
+  llvm::FunctionType* ftype_tvm_parallel_lambda_{nullptr};
   llvm::FunctionType* ftype_tvm_func_call_{nullptr};
   llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
   llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
-  llvm::FunctionType* ftype_tvm_parallel_for_{nullptr};
+  llvm::FunctionType* ftype_tvm_parallel_launch_{nullptr};
+  llvm::FunctionType* ftype_tvm_parallel_barrier_{nullptr};
   llvm::FunctionType* ftype_tvm_register_system_symbol_{nullptr};
   // The acting body
   llvm::BasicBlock* block_{nullptr};
@@ -203,13 +205,22 @@ class CodeGenLLVM :
   std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;

  private:
+  // the parallel group information
+  struct ParallelEnv {
+    VarExpr task_id;
+    VarExpr num_task;
+    bool stride_pattern{false};
+    bool hit_parallel_loop{false};
+    llvm::Value* penv{nullptr};
+  };
   // Get runtime functions
   llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
   llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
   llvm::Value* RuntimeTVMFuncCall();
   llvm::Value* RuntimeTVMGetFuncFromEnv();
   llvm::Value* RuntimeTVMAPISetLastError();
-  llvm::Value* RuntimeTVMParallelFor();
+  llvm::Value* RuntimeTVMParallelLaunch();
+  llvm::Value* RuntimeTVMParallelBarrier();
   // comparison op
   llvm::Value* GetVarValue(const Variable* v) const;
   llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
@@ -230,10 +241,12 @@ class CodeGenLLVM :
   llvm::Value* CreateVecFlip(llvm::Value* vec);
   llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
   llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
-  // Create parallel for.
-  void CreateParallelFor(const For* op);
+  // Create parallel launch
+  void CreateParallelLaunch(const Stmt& body, int num_task);
   // Create serial for
-  void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
+  void CreateSerialFor(llvm::Value* begin,
+                       llvm::Value* end,
+                       llvm::Value* stride,
                        const VarExpr& loop_var, const Stmt& body);
   // Create a new compute scope.
   void CreateComputeScope(const AttrStmt* op);
@@ -262,14 +275,18 @@ class CodeGenLLVM :
   llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
   llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
   llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
-  llvm::GlobalVariable* gv_tvm_parallel_for_{nullptr};
+  llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr};
+  llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr};
   std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
   // context for direct dynamic lookup
   llvm::Function* f_tvm_func_call_{nullptr};
   llvm::Function* f_tvm_get_func_from_env_{nullptr};
   llvm::Function* f_tvm_api_set_last_error_{nullptr};
-  llvm::Function* f_tvm_parallel_for_{nullptr};
+  llvm::Function* f_tvm_parallel_launch_{nullptr};
+  llvm::Function* f_tvm_parallel_barrier_{nullptr};
   llvm::Function* f_tvm_register_system_symbol_{nullptr};
+  // Current parallel environment scope.
+  ParallelEnv parallel_env_;
   // global to packed function handle
   std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
   // List of symbols to be exported to TVM system lib.
src/op/op_util.cc
@@ -70,6 +70,10 @@ MakeLoopNest(const Stage& stage,
               << it_attr->iter_type
               << " in the iter_var_attrs";
       }
+      for (Expr p : it_attr->pragmas) {
+        nest[i + 1].emplace_back(
+            AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op));
+      }
     }
     if (is_one(dom->extent)) {
       nest[i + 1].emplace_back(
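The loop added in MakeLoopNest above is what turns a Stage.pragma annotation into a pragma_scope AttrStmt wrapped around the loop, which the LLVM backend then pattern-matches. A small sketch, assuming this revision, to inspect that in the lowered IR; tvm.lower with simple_mode is assumed to be available as in the tutorials of this era.

import tvm

A = tvm.placeholder((100,), name='A')
T = tvm.compute((100,), lambda i: A[i] + 1, name='T')
s = tvm.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=10)
s[T].pragma(xo, "parallel_launch_point")
# The printed statement should show the xo loop guarded by a pragma_scope attribute.
print(tvm.lower(s, [A, T], simple_mode=True))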
src/runtime/c_runtime_api.cc
@@ -14,8 +14,6 @@
 #include <algorithm>
 #include <string>
 #include <cstdlib>
-#include <thread>
-#include <mutex>
 #include "./runtime_base.h"

 namespace tvm {
@@ -158,24 +156,6 @@ struct TVMRuntimeEntry {
   std::string ret_str;
   std::string last_error;
   TVMByteArray ret_bytes;
-  // threads used in parallel for
-  std::vector<std::thread> par_threads;
-  // errors created in parallel for.
-  std::vector<std::string> par_errors;
-  // number of parallel threads
-  int num_par_threads{1};
-
-  TVMRuntimeEntry() {
-    const char* val = getenv("TVM_NUM_THREADS");
-    if (val == nullptr) {
-      val = getenv("OMP_NUM_THREADS");
-    }
-    if (val != nullptr) {
-      num_par_threads = atoi(val);
-    } else {
-      num_par_threads = std::thread::hardware_concurrency() / 2;
-    }
-  }
 };

 typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
@@ -254,46 +234,6 @@ int TVMBackendFreeWorkspace(int device_type,
   return 0;
 }

-int TVMBackendParallelFor(
-    int64_t begin,
-    int64_t end,
-    int (*lambda)(int64_t begin, int64_t end, void* env),
-    void* env) {
-  TVMRuntimeEntry* rt = TVMAPIRuntimeStore::Get();
-  int nthread = rt->num_par_threads;
-  rt->par_threads.resize(nthread);
-  rt->par_errors.clear();
-  rt->par_errors.resize(nthread);
-  int64_t step = (end - begin + nthread - 1) / nthread;
-  auto fexec = [lambda, env, begin, end, step, rt](int i) {
-    int64_t ibegin = std::min(end, begin + step * i);
-    int64_t iend = std::min(end, begin + step * (i + 1));
-    int rv = (*lambda)(ibegin, iend, env);
-    if (rv != 0) {
-      std::ostringstream os;
-      os << "Thread " << i << " error:" << TVMGetLastError();
-      rt->par_errors[i] = os.str();
-    }
-  };
-  for (int i = 0; i < nthread; ++i) {
-    rt->par_threads[i] = std::thread(fexec, i);
-  }
-  int ret = 0;
-  for (int i = 0; i < nthread; ++i) {
-    rt->par_threads[i].join();
-    if (rt->par_errors[i].length() != 0) ret = -1;
-  }
-  if (ret == 0) return ret;
-  std::ostringstream os;
-  for (int i = 0; i < nthread; ++i) {
-    if (rt->par_errors[i].length() != 0) {
-      os << rt->par_errors[i] << '\n';
-    }
-  }
-  rt->last_error = os.str();
-  return -1;
-}
-
 int TVMFuncFree(TVMFunctionHandle func) {
   API_BEGIN();
   delete static_cast<PackedFunc*>(func);
src/runtime/module_util.h
@@ -40,30 +40,21 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* module_list);
  */
 template<typename FLookup>
 void InitContextFunctions(FLookup flookup) {
-  if (auto *fp = reinterpret_cast<decltype(&TVMFuncCall)*>
-      (flookup("__TVMFuncCall"))) {
-    *fp = TVMFuncCall;
-  }
-  if (auto *fp = reinterpret_cast<decltype(&TVMAPISetLastError)*>
-      (flookup("__TVMAPISetLastError"))) {
-    *fp = TVMAPISetLastError;
-  }
-  if (auto *fp = reinterpret_cast<decltype(&TVMBackendGetFuncFromEnv)*>
-      (flookup("__TVMBackendGetFuncFromEnv"))) {
-    *fp = TVMBackendGetFuncFromEnv;
-  }
-  if (auto *fp = reinterpret_cast<decltype(&TVMBackendAllocWorkspace)*>
-      (flookup("__TVMBackendAllocWorkspace"))) {
-    *fp = TVMBackendAllocWorkspace;
-  }
-  if (auto *fp = reinterpret_cast<decltype(&TVMBackendFreeWorkspace)*>
-      (flookup("__TVMBackendFreeWorkspace"))) {
-    *fp = TVMBackendFreeWorkspace;
-  }
-  if (auto *fp = reinterpret_cast<decltype(&TVMBackendParallelFor)*>
-      (flookup("__TVMBackendParallelFor"))) {
-    *fp = TVMBackendParallelFor;
-  }
+  #define TVM_INIT_CONTEXT_FUNC(FuncName)                     \
+    if (auto *fp = reinterpret_cast<decltype(&FuncName)*>     \
+        (flookup("__" #FuncName))) {                          \
+      *fp = FuncName;                                         \
+    }
+  // Initialize the functions
+  TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
+  TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
+  TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
+  TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
+  TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
+  TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
+  TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
+  #undef TVM_INIT_CONTEXT_FUNC
 }
 }  // namespace runtime
 }  // namespace tvm
src/runtime/thread_pool.cc  (new file, mode 100644)

/*!
 *  Copyright (c) 2017 by Contributors
 * \file thread_pool.cc
 * \brief Threadpool for multi-threading runtime.
 */
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_backend_api.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <thread>
#include <condition_variable>
#include <mutex>
#include <atomic>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <sstream>

namespace tvm {
namespace runtime {

// stride in the page, fit to cache line.
constexpr int kSyncStride = 64 / sizeof(std::atomic<int>);

/*!
 * \brief Thread local master environment.
 */
class ParallelLauncher {
 public:
  // Reset the the task request.
  void Init(FTVMParallelLambda flambda,
            void* cdata,
            int num_task,
            bool need_sync) {
    std::lock_guard<std::mutex> lock(mutex_);
    num_pending_ = num_task;
    this->cdata = cdata;
    this->flambda = flambda;
    this->env.num_task = num_task;
    has_error_ = false;
    // reshape
    if (static_cast<size_t>(num_task) > par_errors_.size()) {
      par_errors_.resize(num_task + 1);
      if (need_sync) {
        delete[] sync_counter_;
        sync_counter_ = new std::atomic<int>[num_task * kSyncStride];
      }
    }
    if (need_sync) {
      for (int i = 0; i < num_task; ++i) {
        sync_counter_[i * kSyncStride].store(
            0, std::memory_order_relaxed);
      }
      this->env.sync_handle = sync_counter_;
    } else {
      this->env.sync_handle = nullptr;
    }
  }
  ~ParallelLauncher() {
    delete[] sync_counter_;
  }
  // Wait n jobs to finish
  int WaitForJobs() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] {
        return num_pending_ == 0;
      });
    if (!has_error_) return 0;
    std::ostringstream os;
    for (size_t i = 0; i < par_errors_.size(); ++i) {
      if (par_errors_[i].length() != 0) {
        os << "Task " << i << " error: " << par_errors_[i] << '\n';
        par_errors_[i].clear();
      }
    }
    TVMAPISetLastError(os.str().c_str());
    return -1;
  }
  // Signal that one job has finished.
  void SignalJobError(int task_id) {
    std::unique_lock<std::mutex> lock(mutex_);
    --num_pending_;
    par_errors_[task_id] = TVMGetLastError();
    has_error_ = true;
    if (num_pending_ == 0) {
      lock.unlock();
      cv_.notify_one();
    }
  }
  // Signal that one job has finished.
  void SignalJobFinish() {
    std::unique_lock<std::mutex> lock(mutex_);
    --num_pending_;
    if (num_pending_ == 0) {
      lock.unlock();
      cv_.notify_one();
    }
  }
  // Get thread local version of the store.
  static ParallelLauncher* ThreadLocal() {
    return dmlc::ThreadLocalStore<ParallelLauncher>::Get();
  }
  // The parallel lambda
  FTVMParallelLambda flambda;
  // The closure data
  void* cdata;
  // Local env
  TVMParallelGroupEnv env;
  // Whether this thread is worker of the pool.
  // used to prevent recursive launch.
  bool is_worker{false};

 private:
  // The mutex to access local env.
  std::mutex mutex_;
  // The conditional variable.
  std::condition_variable cv_;
  // The pending jobs.
  uint32_t num_pending_;
  // Whether error has been countered.
  bool has_error_;
  // The counter page.
  std::atomic<int32_t>* sync_counter_{nullptr};
  // The error message
  std::vector<std::string> par_errors_;
};

/*! \brief Working queue for each thread */
class ParallelTaskQueue {
 public:
  /*! \brief The task entry */
  struct Task {
    ParallelLauncher* launcher;
    int32_t task_id;
  };
  ParallelTaskQueue() {
    ring_.resize(2);
  }
  /*!
   * \brief Signal to kill the job.
   */
  void SignalForKill() {
    std::lock_guard<std::mutex> lock(mutex_);
    exit_now_.store(true);
    cv_.notify_all();
  }
  /*!
   * \brief Push task into the queue.
   * \param task The task to be pushed.
   */
  void Push(Task task) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (num_pending_ < ring_.size()) {
      CHECK_NE(ring_.size(), 0U);
      ring_[(head_ + num_pending_) % ring_.size()] = task;
      ++num_pending_;
    } else {
      size_t old_size = ring_.size();
      ring_.resize(old_size * 2);
      if (head_ + num_pending_ > old_size) {
        // copy the ring overflow part into the tail.
        size_t ncopy = head_ + num_pending_ - old_size;
        memcpy(&ring_[0] + old_size, &ring_[0], ncopy * sizeof(Task));
      }
      ring_[(head_ + num_pending_) % ring_.size()] = task;
      ++num_pending_;
    }
    if (nwait_consumer_ != 0) {
      lock.unlock();
      cv_.notify_one();
    }
  }
  /*!
   * \brief Pop task from the queue
   * \param task The task to be poped.
   * \param timeout The number of cycles to spin before sleep.
   * \return Whether pop is successful or we need to exit now.
   */
  bool Pop(Task* task, int timeout = 10) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (num_pending_ != 0) {
      *task = ring_[head_];
      head_ = (head_ + 1) % ring_.size();
      --num_pending_;
      if (exit_now_.load()) return false;
    } else {
      lock.unlock();
      // do a bit spin and busy waiting before sleep.
      for (int i = 0; i < timeout && num_pending_ == 0; ++i) {
        std::this_thread::yield();
      }
      lock.lock();
      ++nwait_consumer_;
      cv_.wait(lock, [this] {
          return num_pending_ != 0 || exit_now_.load();
        });
      --nwait_consumer_;
      *task = ring_[head_];
      head_ = (head_ + 1) % ring_.size();
      --num_pending_;
      if (exit_now_.load()) return false;
    }
    return true;
  }

 private:
  // Number of the elments in the queue
  uint32_t num_pending_{0};
  // Queue head
  uint32_t head_{0};
  // Number of consumers to wait.
  uint32_t nwait_consumer_{0};
  // internal mutex
  std::mutex mutex_;
  // cv for consumer
  std::condition_variable cv_;
  // signal for exit now
  std::atomic<bool> exit_now_{false};
  // The internal ring.
  std::vector<Task> ring_;
};

// The thread pool
class ThreadPool {
 public:
  ThreadPool() {
    const char* val = getenv("TVM_NUM_THREADS");
    if (val == nullptr) {
      val = getenv("OMP_NUM_THREADS");
    }
    if (val != nullptr) {
      num_workers_ = atoi(val);
    } else {
#if defined(_M_X64) || defined(__x86_64__)
      // Half to not count hyper threading.
      num_workers_ = std::thread::hardware_concurrency() / 2;
#else
      num_workers_ = std::thread::hardware_concurrency();
#endif
    }
    num_workers_ = std::max(num_workers_, 1);
    this->Init();
  }
  ~ThreadPool() {
    for (std::unique_ptr<ParallelTaskQueue>& q : queues_) {
      q->SignalForKill();
    }
    for (std::thread& t : threads_) {
      t.join();
    }
  }
  int Launch(FTVMParallelLambda flambda,
             void* cdata,
             int num_task,
             int need_sync) {
    ParallelLauncher* launcher = ParallelLauncher::ThreadLocal();
    CHECK(!launcher->is_worker)
        << "Cannot launch parallel job inside worker, consider fuse then parallel";
    if (num_task == 0) {
      num_task = num_workers_;
    }
    if (need_sync != 0) {
      CHECK_LE(num_task, num_workers_)
          << "Request parallel sync task larger than number of threads available "
          << " workers=" << num_workers_ << " request=" << num_task;
    }
    launcher->Init(flambda, cdata, num_task, need_sync != 0);
    ParallelTaskQueue::Task tsk;
    tsk.launcher = launcher;
    for (int i = 0; i < num_task; ++i) {
      tsk.task_id = i;
      queues_[i]->Push(tsk);
    }
    return launcher->WaitForJobs();
  }
  static ThreadPool* Global() {
    static ThreadPool inst;
    return &inst;
  }

 private:
  // Initialize the pool.
  void Init() {
    for (int i = 0; i < num_workers_; ++i) {
      queues_.emplace_back(
          std::unique_ptr<ParallelTaskQueue>(new ParallelTaskQueue()));
    }
    threads_.resize(num_workers_);
    for (int i = 0; i < num_workers_; ++i) {
      threads_[i] = std::thread([this, i] {
          this->RunWorker(queues_[i].get());
        });
    }
  }
  // Internal worker function.
  void RunWorker(ParallelTaskQueue* queue) {
    ParallelTaskQueue::Task task;
    ParallelLauncher::ThreadLocal()->is_worker = true;
    while (queue->Pop(&task)) {
      CHECK(task.launcher != nullptr);
      TVMParallelGroupEnv* penv = &(task.launcher->env);
      void* cdata = task.launcher->cdata;
      if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) {
        task.launcher->SignalJobFinish();
      } else {
        task.launcher->SignalJobError(task.task_id);
      }
    }
  }
  // Number of workers
  int num_workers_;
  std::vector<std::unique_ptr<ParallelTaskQueue> > queues_;
  std::vector<std::thread> threads_;
};

}  // namespace runtime
}  // namespace tvm

int TVMBackendParallelLaunch(
    FTVMParallelLambda flambda,
    void* cdata,
    int num_task) {
  return tvm::runtime::ThreadPool::Global()->Launch(
      flambda, cdata, num_task, 1);
}

int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {
  using tvm::runtime::kSyncStride;
  int num_task = penv->num_task;
  std::atomic<int>* sync_counter =
      reinterpret_cast<std::atomic<int>*>(penv->sync_handle);
  int old_counter = sync_counter[task_id * kSyncStride].fetch_add(
      1, std::memory_order_release);
  for (int i = 0; i < num_task; ++i) {
    if (i != task_id) {
      while (sync_counter[i * kSyncStride].load(
                 std::memory_order_relaxed) <= old_counter) {
        std::this_thread::yield();
      }
    }
  }
  std::atomic_thread_fence(std::memory_order_acquire);
  return 0;
}
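The pool size above comes from TVM_NUM_THREADS, then OMP_NUM_THREADS, otherwise hardware concurrency (halved on x86_64 to skip hyperthreads). A hedged sketch of pinning the worker count from Python follows; the variable has to be set before the first parallel launch creates the global pool, and the vector-add kernel is only an assumed stand-in for any parallel workload.

import os
os.environ["TVM_NUM_THREADS"] = "4"   # must precede the first parallel launch

import tvm
import numpy as np

n = 1 << 16
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
s[B].parallel(B.op.axis[0])

f = tvm.build(s, [A, B], "llvm")
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype))
b = tvm.nd.array(np.zeros(n, dtype=B.dtype))
f(a, b)   # executed by a pool of 4 worker threads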
src/schedule/schedule_lang.cc
@@ -340,6 +340,19 @@ Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
   return *this;
 }

+Stage& Stage::pragma(IterVar var, const std::string& pragma_type) {   // NOLINT(*)
+  if (pragma_type == "unroll") {
+    this->unroll(var);
+  } else if (pragma_type == "vectorize") {
+    this->vectorize(var);
+  } else {
+    UpdateIterVarAttr(operator->(), var, [pragma_type](IterVarAttrNode* n) {
+        n->pragmas.push_back(ir::StringImm::make(pragma_type));
+      });
+  }
+  return *this;
+}
+
 Stage& Stage::prefetch(const Tensor& tensor, IterVar var, Expr offset) {
   StageNode* self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
tests/python/unittest/test_codegen_llvm.py
@@ -28,8 +28,13 @@ def test_llvm_add_pipeline():
     C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
     s = tvm.create_schedule(C.op)
     xo, xi = s[C].split(C.op.axis[0], factor=4)
-    s[C].parallel(xo)
+    xo1, xo2 = s[C].split(xo, factor=13)
+    s[C].parallel(xo2)
+    s[C].pragma(xo1, "parallel_launch_point")
+    s[C].pragma(xo2, "parallel_stride_pattern")
+    s[C].pragma(xo2, "parallel_barrier_when_finish")
     s[C].vectorize(xi)
+
     def check_llvm():
         if not tvm.module.enabled("llvm"):
             return
@@ -167,9 +172,9 @@ def test_multiple_func():
 if __name__ == "__main__":
+    test_llvm_add_pipeline()
     test_llvm_intrin()
     test_multiple_func()
-    test_llvm_add_pipeline()
     test_llvm_flip_pipeline()
     test_llvm_madd_pipeline()
     test_llvm_temp_space()
tests/python/unittest/test_codegen_vm_basic.py
@@ -74,7 +74,27 @@ def test_stack_vm_cond():
         np.testing.assert_equal(a.asnumpy(), y)
     run_jit(fapi, check)

+def test_vm_parallel():
+    dtype = 'int64'
+    n = tvm.var('n')
+    Ab = tvm.decl_buffer((n, ), dtype)
+    i = tvm.var('i')
+    ib = tvm.ir_builder.create()
+    A = ib.buffer_ptr(Ab)
+    with ib.for_range(0, n, "i", for_type="parallel") as i:
+        A[i] = A[i] + 1
+    stmt = ib.get()
+    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
+    fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
+    def check(f):
+        a = tvm.nd.array(np.zeros(10, dtype=dtype))
+        f(a)
+        np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
+    run_jit(fapi, check)
+
 if __name__ == "__main__":
+    test_vm_parallel()
+    test_stack_vm_loop()
     test_stack_vm_basic()
     test_stack_vm_cond()
-    test_stack_vm_loop()
tests/python/unittest/test_lang_schedule.py
@@ -30,6 +30,7 @@ def test_schedule_create():
     assert isinstance(s_loaded, tvm.schedule.Schedule)
     assert(str(s_loaded.outputs[0].body) == str(s.outputs[0].body))

+
 def test_reorder():
     m = tvm.var('m')
     A = tvm.placeholder((m,), name='A')
@@ -91,6 +92,21 @@ def test_vectorize():
     assert s[T].iter_var_attrs[xi].iter_type == UNROLL
     assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE

+def test_pragma():
+    m = 100
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute((m,), lambda i: A[i])
+
+    s = tvm.create_schedule(T.op)
+    xo, xi = s[T].split(T.op.axis[0], factor=10)
+    s[T].pragma(xo, "pragma1")
+    s[T].pragma(xi, "vectorize")
+    VECTORIZE = tvm.schedule.IterVar.Vectorized
+    assert s[T].iter_var_attrs[xo].pragmas[0].value == "pragma1"
+    assert s[T].iter_var_attrs[xi].iter_type == VECTORIZE
+
+
 def test_rfactor():
     n = tvm.var('n')
     k1 = tvm.reduce_axis((0, n), name="k1")
@@ -141,6 +157,7 @@ def test_tensor_intrin():
 if __name__ == "__main__":
+    test_pragma()
     test_tensor_intrin()
     test_rfactor()
     test_schedule_create()
...
web/web_runtime.cc
View file @
41768cf9
...
@@ -50,3 +50,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
...
@@ -50,3 +50,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
});
});
}
// namespace contrib
}
// namespace contrib
}
// namespace tvm
}
// namespace tvm
// dummy parallel runtime
int
TVMBackendParallelLaunch
(
FTVMParallelLambda
flambda
,
void
*
cdata
,
int
num_task
)
{
TVMAPISetLastError
(
"Parallel is not supported in Web runtime"
);
return
-
1
;
}
int
TVMBackendParallelBarrier
(
int
task_id
,
TVMParallelGroupEnv
*
penv
)
{
return
0
;
}