Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f6c043eb
Commit
f6c043eb
authored
Feb 25, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 25, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LLVM/RUNTIME] Support Parallel for on CPU (#54)
parent
2f462cca
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
298 additions
and
86 deletions
+298
-86
include/tvm/ir_pass.h
+5
-4
include/tvm/runtime/c_runtime_api.h
+27
-7
include/tvm/schedule.h
+8
-1
python/tvm/schedule.py
+10
-0
src/api/api_lang.cc
+6
-0
src/codegen/llvm/codegen_llvm.cc
+111
-32
src/codegen/llvm/codegen_llvm.h
+8
-1
src/lang/lowered_func.cc
+16
-0
src/pass/make_api.cc
+1
-1
src/pass/split_host_device.cc
+3
-3
src/runtime/c_runtime_api.cc
+65
-4
src/schedule/schedule_lang.cc
+6
-0
src/schedule/schedule_ops.cc
+1
-0
tests/python/unittest/test_codegen_llvm.py
+31
-0
tests/python/unittest/test_codegen_vm_basic.py
+0
-33
No files found.
include/tvm/ir_pass.h
View file @
f6c043eb
...
...
@@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
int
num_unpacked_args
);
/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array
<
Var
>
UndefinedVars
(
const
LoweredFunc
&
f
);
Array
<
Var
>
UndefinedVars
(
const
Stmt
&
stmt
,
const
Array
<
Var
>&
defs
);
/*!
* \brief Split the function into a host function and device functions.
...
...
include/tvm/runtime/c_runtime_api.h
View file @
f6c043eb
...
...
@@ -226,6 +226,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVMContext
ctx
);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
*/
TVM_DLL
int
TVMModFree
(
TVMModuleHandle
mod
);
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
...
...
@@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVM_DLL
int
TVMBackendGetFuncFromEnv
(
void
*
mod_node
,
const
char
*
func_name
,
TVMFunctionHandle
*
out
);
/*!
* \brief Free the Module
* \param mod The module to be freed.
* \brief Backend function for running parallel for loop.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* The all functions remains valid until TVMFuncFree is called.
* \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL
int
TVMModFree
(
TVMModuleHandle
mod
);
TVM_DLL
int
TVMBackendParallelFor
(
int64_t
begin
,
int64_t
end
,
int
(
*
lambda
)(
int64_t
begin
,
int64_t
end
,
void
*
env
),
void
*
env
);
/*!
* \brief Free the function when it is no longer needed.
...
...
include/tvm/schedule.h
View file @
f6c043eb
...
...
@@ -34,7 +34,8 @@ enum AttachType : int {
/*! \brief IterVar type */
enum
IterVarType
:
int
{
kUnrolled
=
1
,
kVectorized
=
2
kVectorized
=
2
,
kParallel
=
3
};
/*! \brief Stage, contains scheduling for a stage of computation. */
...
...
@@ -153,6 +154,12 @@ class Stage : public NodeRef {
*/
Stage
&
unroll
(
IterVar
var
);
// NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
Stage
&
parallel
(
IterVar
var
);
// NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
...
...
python/tvm/schedule.py
View file @
f6c043eb
...
...
@@ -257,3 +257,13 @@ class Stage(NodeBase):
The iteration to be unrolled.
"""
_api_internal
.
_StageUnroll
(
self
,
var
)
def
parallel
(
self
,
var
):
"""Parallelize the iteration.
Parameters
----------
var : IterVar
The iteration to be parallelized.
"""
_api_internal
.
_StageParallel
(
self
,
var
)
src/api/api_lang.cc
View file @
f6c043eb
...
...
@@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
.
vectorize
(
args
[
1
]);
});
TVM_REGISTER_API
(
_StageParallel
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
parallel
(
args
[
1
]);
});
TVM_REGISTER_API
(
_ScheduleNormalize
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Schedule
()
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
f6c043eb
...
...
@@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../arithmetic/compute_expr.h"
...
...
@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int8_
=
llvm
::
Type
::
getInt8Ty
(
*
ctx
);
t_int16_
=
llvm
::
Type
::
getInt16Ty
(
*
ctx
);
t_int32_
=
llvm
::
Type
::
getInt32Ty
(
*
ctx
);
t_int64_
=
llvm
::
Type
::
getInt64Ty
(
*
ctx
);
t_float64_
=
llvm
::
Type
::
getDoubleTy
(
*
ctx
);
t_tvm_index_
=
llvm
::
Type
::
getIntNTy
(
*
ctx
,
sizeof
(
tvm_index_t
)
*
8
);
t_tvm_context_
=
llvm
::
StructType
::
create
({
t_int_
,
t_int_
});
...
...
@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_type_
,
t_tvm_context_
});
t_tvm_value_
=
llvm
::
StructType
::
create
({
t_float64_
});
t_f_tvm_par_for_lambda_
=
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_void_p_
},
false
);
md_builder_
.
reset
(
new
llvm
::
MDBuilder
(
*
ctx
));
md_very_likely_branch_
=
md_builder_
->
createBranchWeights
(
1
<<
30
,
0
);
...
...
@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
f_tvm_api_set_last_error_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_void_
,
{
t_char_
->
getPointerTo
()},
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMAPISetLastError"
,
module_
.
get
());
f_tvm_parallel_for_
=
llvm
::
Function
::
Create
(
llvm
::
FunctionType
::
get
(
t_int_
,
{
t_int64_
,
t_int64_
,
t_f_tvm_par_for_lambda_
->
getPointerTo
(),
t_void_p_
}
,
false
),
llvm
::
Function
::
ExternalLinkage
,
"TVMBackendParallelFor"
,
module_
.
get
());
this
->
InitTarget
(
target_triple
);
// initialize builder
builder_
.
reset
(
new
IRBuilder
(
*
ctx
));
...
...
@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
}
llvm
::
BasicBlock
*
block
=
llvm
::
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
function_
);
builder_
->
SetInsertPoint
(
block
);
builder_
->
CreateRet
(
builder_
->
CreateCall
(
f
,
args
));
llvm
::
CallInst
*
call
=
builder_
->
CreateCall
(
f
,
args
);
call
->
setTailCall
(
true
);
builder_
->
CreateRet
(
call
);
}
class
FPassManager
:
public
llvm
::
legacy
::
FunctionPassManager
{
...
...
@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return
nullptr
;
}
llvm
::
BasicBlock
*
CodeGenLLVM
::
Check
Packed
CallSuccess
(
llvm
::
Value
*
retcode
)
{
llvm
::
BasicBlock
*
CodeGenLLVM
::
CheckCallSuccess
(
llvm
::
Value
*
retcode
)
{
// create emit codes that checks and load the function.
using
llvm
::
BasicBlock
;
BasicBlock
*
fail_block
=
BasicBlock
::
Create
(
...
...
@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
return
end_block
;
}
void
CodeGenLLVM
::
Visit_
(
const
For
*
op
)
{
using
llvm
::
BasicBlock
;
BasicBlock
*
for_head
=
BasicBlock
::
Create
(
*
ctx_
,
"for_head"
,
function_
);
BasicBlock
*
for_body
=
BasicBlock
::
Create
(
*
ctx_
,
"for_body"
,
function_
);
BasicBlock
*
for_end
=
BasicBlock
::
Create
(
*
ctx_
,
"for_end"
,
function_
);
BasicBlock
*
pre_block
=
builder_
->
GetInsertBlock
();
CHECK
(
is_zero
(
op
->
min
));
Type
t
=
op
->
min
.
type
();
llvm
::
Value
*
init
=
ConstInt32
(
0
);
llvm
::
Value
*
extent
=
MakeValue
(
op
->
extent
);
builder_
->
CreateBr
(
for_head
);
builder_
->
SetInsertPoint
(
for_head
);
llvm
::
PHINode
*
index
=
builder_
->
CreatePHI
(
LLVMType
(
t
),
2
);
index
->
addIncoming
(
init
,
pre_block
);
llvm
::
Value
*
cond
=
CreateLT
(
t
,
index
,
extent
);
builder_
->
CreateCondBr
(
cond
,
for_body
,
for_end
,
md_very_likely_branch_
);
// body of for
builder_
->
SetInsertPoint
(
for_body
);
var_map_
[
op
->
loop_var
.
get
()]
=
index
;
this
->
Visit
(
op
->
body
);
llvm
::
Value
*
next_index
=
CreateAdd
(
t
,
index
,
ConstInt32
(
1
));
index
->
addIncoming
(
next_index
,
builder_
->
GetInsertBlock
());
builder_
->
CreateBr
(
for_head
);
// end of for
builder_
->
SetInsertPoint
(
for_end
);
if
(
op
->
for_type
==
ForType
::
Serial
)
{
CreateSerialFor
(
ConstInt32
(
0
),
MakeValue
(
op
->
extent
),
op
->
loop_var
,
op
->
body
);
}
else
if
(
op
->
for_type
==
ForType
::
Parallel
)
{
CreateParallelFor
(
op
);
}
else
{
LOG
(
FATAL
)
<<
"cannot handle for type "
<<
op
->
for_type
;
}
}
void
CodeGenLLVM
::
Visit_
(
const
IfThenElse
*
op
)
{
...
...
@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
llvm
::
Value
*
ctx
=
builder_
->
CreateLoad
(
gv_mod_ctx_
);
llvm
::
Value
*
retcode
=
builder_
->
CreateCall
(
f_tvm_get_func_from_env_
,
{
ctx
,
GetConstString
(
fname
),
out
});
init_block
=
Check
Packed
CallSuccess
(
retcode
);
init_block
=
CheckCallSuccess
(
retcode
);
llvm
::
Value
*
loaded_handle
=
builder_
->
CreateAlignedLoad
(
out
,
align
);
builder_
->
CreateBr
(
end_block
);
// end block
...
...
@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
}
llvm
::
Value
*
ret_value
=
builder_
->
CreateAlloca
(
t_tvm_value_
);
llvm
::
Value
*
ret_tcode
=
builder_
->
CreateAlloca
(
t_int_
);
Check
Packed
CallSuccess
(
CheckCallSuccess
(
builder_
->
CreateCall
(
f_tvm_func_call_
,
{
handle
,
targs
,
tcodes
,
ConstInt32
(
nargs
),
ret_value
,
ret_tcode
}));
...
...
@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
}
void
CodeGenLLVM
::
CreateParallelFor
(
const
For
*
op
)
{
using
llvm
::
BasicBlock
;
llvm
::
Value
*
min
=
MakeValue
(
op
->
min
);
llvm
::
Value
*
extent
=
MakeValue
(
op
->
extent
);
min
=
builder_
->
CreateIntCast
(
min
,
t_int64_
,
op
->
min
.
type
().
is_int
());
extent
=
builder_
->
CreateIntCast
(
extent
,
t_int64_
,
op
->
min
.
type
().
is_int
());
// fields to be packed into closure.
Var
loop_var
(
op
->
loop_var
.
node_
);
Array
<
Var
>
vfields
=
ir
::
UndefinedVars
(
op
->
body
,
{
loop_var
});
std
::
vector
<
llvm
::
Type
*>
fields
;
for
(
Var
v
:
vfields
)
{
auto
it
=
var_map_
.
find
(
v
.
get
());
CHECK
(
it
!=
var_map_
.
end
());
fields
.
push_back
(
it
->
second
->
getType
());
}
// closure data
llvm
::
StructType
*
tcdata
=
llvm
::
StructType
::
create
(
fields
);
llvm
::
Function
*
f
=
llvm
::
Function
::
Create
(
t_f_tvm_par_for_lambda_
,
llvm
::
Function
::
PrivateLinkage
,
"__tvm_par_for_lambda"
,
module_
.
get
());
// allocate and setup the closure, call the closure.
llvm
::
Value
*
cdata
=
builder_
->
CreateAlloca
(
tcdata
,
ConstInt32
(
1
));
llvm
::
Value
*
zero
=
ConstInt32
(
0
);
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
builder_
->
CreateStore
(
var_map_
.
at
(
vfields
[
i
].
get
()),
builder_
->
CreateInBoundsGEP
(
cdata
,
{
zero
,
ConstInt32
(
i
)}));
}
BasicBlock
*
par_for_end
=
CheckCallSuccess
(
builder_
->
CreateCall
(
f_tvm_parallel_for_
,
{
min
,
extent
,
f
,
builder_
->
CreatePointerCast
(
cdata
,
t_void_p_
)}));
// Setup the closure function.
BasicBlock
*
lambda_entry
=
BasicBlock
::
Create
(
*
ctx_
,
"entry"
,
f
);
builder_
->
SetInsertPoint
(
lambda_entry
);
auto
it
=
f
->
arg_begin
();
llvm
::
Value
*
begin
=
&
(
*
it
++
);
llvm
::
Value
*
end
=
&
(
*
it
++
);
cdata
=
&
(
*
it
++
);
begin
=
CreateCast
(
Int
(
64
),
op
->
loop_var
.
type
(),
begin
);
end
=
CreateCast
(
Int
(
64
),
op
->
loop_var
.
type
(),
end
);
cdata
=
builder_
->
CreatePointerCast
(
cdata
,
tcdata
->
getPointerTo
());
// setup new variable map, swap it with current var context.
std
::
unordered_map
<
const
Variable
*
,
llvm
::
Value
*>
new_vmap
;
for
(
size_t
i
=
0
;
i
<
vfields
.
size
();
++
i
)
{
new_vmap
[
vfields
[
i
].
get
()]
=
builder_
->
CreateLoad
(
builder_
->
CreateInBoundsGEP
(
cdata
,
{
zero
,
ConstInt32
(
i
)}));
}
std
::
swap
(
function_
,
f
);
std
::
swap
(
new_vmap
,
var_map_
);
CreateSerialFor
(
begin
,
end
,
op
->
loop_var
,
op
->
body
);
builder_
->
CreateRet
(
ConstInt32
(
0
));
// swap the var map back, now we are back on track.
std
::
swap
(
new_vmap
,
var_map_
);
std
::
swap
(
function_
,
f
);
builder_
->
SetInsertPoint
(
par_for_end
);
}
void
CodeGenLLVM
::
CreateSerialFor
(
llvm
::
Value
*
begin
,
llvm
::
Value
*
end
,
const
VarExpr
&
loop_var
,
const
Stmt
&
body
)
{
using
llvm
::
BasicBlock
;
Type
t
=
loop_var
.
type
();
BasicBlock
*
for_head
=
BasicBlock
::
Create
(
*
ctx_
,
"for_head"
,
function_
);
BasicBlock
*
for_body
=
BasicBlock
::
Create
(
*
ctx_
,
"for_body"
,
function_
);
BasicBlock
*
for_end
=
BasicBlock
::
Create
(
*
ctx_
,
"for_end"
,
function_
);
BasicBlock
*
pre_block
=
builder_
->
GetInsertBlock
();
builder_
->
CreateBr
(
for_head
);
builder_
->
SetInsertPoint
(
for_head
);
llvm
::
PHINode
*
index
=
builder_
->
CreatePHI
(
begin
->
getType
(),
2
);
index
->
addIncoming
(
begin
,
pre_block
);
llvm
::
Value
*
cond
=
CreateLT
(
t
,
index
,
end
);
builder_
->
CreateCondBr
(
cond
,
for_body
,
for_end
,
md_very_likely_branch_
);
// body of for
builder_
->
SetInsertPoint
(
for_body
);
var_map_
[
loop_var
.
get
()]
=
index
;
this
->
Visit
(
body
);
llvm
::
Value
*
next_index
=
CreateAdd
(
t
,
index
,
ConstInt32
(
1
));
index
->
addIncoming
(
next_index
,
builder_
->
GetInsertBlock
());
builder_
->
CreateBr
(
for_head
);
// end of for
builder_
->
SetInsertPoint
(
for_end
);
}
}
// namespace codegen
}
// namespace tvm
#endif // TVM_LLVM_VERSION
src/codegen/llvm/codegen_llvm.h
View file @
f6c043eb
...
...
@@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
llvm
::
StructType
*
t_tvm_type_
{
nullptr
};
llvm
::
StructType
*
t_tvm_array_
{
nullptr
};
llvm
::
StructType
*
t_tvm_value_
{
nullptr
};
llvm
::
FunctionType
*
t_f_tvm_par_for_lambda_
{
nullptr
};
// tvm api functions
llvm
::
Function
*
f_tvm_func_call_
{
nullptr
};
llvm
::
Function
*
f_tvm_get_func_from_env_
{
nullptr
};
llvm
::
Function
*
f_tvm_api_set_last_error_
{
nullptr
};
llvm
::
Function
*
f_tvm_parallel_for_
{
nullptr
};
// The acting body
llvm
::
BasicBlock
*
block_
{
nullptr
};
// Last value returned codegen call.
...
...
@@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
llvm
::
Value
*
CreateBufferPtr
(
Type
t
,
llvm
::
Value
*
buffer
,
llvm
::
Value
*
index
);
llvm
::
Value
*
CreateCast
(
Type
from
,
Type
to
,
llvm
::
Value
*
value
);
llvm
::
Value
*
GetPackedFuncHandle
(
const
std
::
string
&
str
);
// Create parallel for.
void
CreateParallelFor
(
const
For
*
op
);
// Create serial for
void
CreateSerialFor
(
llvm
::
Value
*
begin
,
llvm
::
Value
*
end
,
const
VarExpr
&
loop_var
,
const
Stmt
&
body
);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm
::
BasicBlock
*
Check
Packed
CallSuccess
(
llvm
::
Value
*
retcode
);
llvm
::
BasicBlock
*
CheckCallSuccess
(
llvm
::
Value
*
retcode
);
// Initialize target
void
InitTarget
(
const
std
::
string
&
target
);
// Add a function to set global module context
...
...
src/lang/lowered_func.cc
0 → 100644
View file @
f6c043eb
/*!
* Copyright (c) 2017 by Contributors
* \file lowered_func.cc
*/
#include <tvm/lowered_func.h>
namespace
tvm
{
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
LoweredFuncNode
>
([](
const
LoweredFuncNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"LoweredFunc("
<<
op
->
name
<<
", "
<<
op
<<
")"
;
});
TVM_REGISTER_NODE_TYPE
(
LoweredFuncNode
);
}
// namespace tvm
src/pass/make_api.cc
View file @
f6c043eb
...
...
@@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
n
->
is_packed_func
=
num_unpacked_args
==
0
;
n
->
body
=
MergeNest
({
seq_init
,
seq_check
},
body
);
LoweredFunc
f
(
n
);
Array
<
Var
>
undefined
=
UndefinedVars
(
f
);
Array
<
Var
>
undefined
=
UndefinedVars
(
f
->
body
,
f
->
args
);
if
(
undefined
.
size
()
!=
0
)
{
std
::
ostringstream
os
;
for
(
Var
v
:
undefined
)
{
...
...
src/pass/split_host_device.cc
View file @
f6c043eb
...
...
@@ -220,12 +220,12 @@ class HostDeviceSplitter : public IRMutator {
};
Array
<
Var
>
UndefinedVars
(
const
LoweredFunc
&
f
)
{
Array
<
Var
>
UndefinedVars
(
const
Stmt
&
stmt
,
const
Array
<
Var
>&
args
)
{
IRUseDefAnalysis
m
;
for
(
Var
arg
:
f
->
args
)
{
for
(
Var
arg
:
args
)
{
m
.
use_count_
[
arg
.
get
()]
=
0
;
}
m
.
Mutate
(
f
->
body
);
m
.
Mutate
(
stmt
);
return
m
.
undefined_
;
}
...
...
src/runtime/c_runtime_api.cc
View file @
f6c043eb
...
...
@@ -7,8 +7,11 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <dmlc/timer.h>
#include <algorithm>
#include <string>
#include <cstdlib>
#include <thread>
#include "./runtime_base.h"
#include "./device_api.h"
...
...
@@ -71,6 +74,24 @@ using namespace tvm::runtime;
struct
TVMRuntimeEntry
{
std
::
string
ret_str
;
std
::
string
last_error
;
// threads used in parallel for
std
::
vector
<
std
::
thread
>
par_threads
;
// errors created in parallel for.
std
::
vector
<
std
::
string
>
par_errors
;
// number of parallel threads
int
num_par_threads
{
1
};
TVMRuntimeEntry
()
{
const
char
*
val
=
getenv
(
"TVM_NUM_THREADS"
);
if
(
val
==
nullptr
)
{
val
=
getenv
(
"OMP_NUM_THREADS"
);
}
if
(
val
!=
nullptr
)
{
num_par_threads
=
atoi
(
val
);
}
else
{
num_par_threads
=
std
::
thread
::
hardware_concurrency
();
}
}
};
typedef
dmlc
::
ThreadLocalStore
<
TVMRuntimeEntry
>
TVMAPIRuntimeStore
;
...
...
@@ -123,6 +144,12 @@ int TVMModPreCompile(TVMModuleHandle mod,
API_END
();
}
int
TVMModFree
(
TVMModuleHandle
mod
)
{
API_BEGIN
();
delete
static_cast
<
Module
*>
(
mod
);
API_END
();
}
int
TVMBackendGetFuncFromEnv
(
void
*
mod_node
,
const
char
*
func_name
,
TVMFunctionHandle
*
func
)
{
...
...
@@ -132,10 +159,44 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
API_END
();
}
int
TVMModFree
(
TVMModuleHandle
mod
)
{
API_BEGIN
();
delete
static_cast
<
Module
*>
(
mod
);
API_END
();
int
TVMBackendParallelFor
(
int64_t
begin
,
int64_t
end
,
int
(
*
lambda
)(
int64_t
begin
,
int64_t
end
,
void
*
env
),
void
*
env
)
{
TVMRuntimeEntry
*
rt
=
TVMAPIRuntimeStore
::
Get
();
int
nthread
=
rt
->
num_par_threads
;
rt
->
par_threads
.
resize
(
nthread
);
rt
->
par_errors
.
clear
();
rt
->
par_errors
.
resize
(
nthread
);
int64_t
step
=
(
end
-
begin
+
nthread
-
1
)
/
nthread
;
auto
fexec
=
[
lambda
,
env
,
begin
,
end
,
step
,
rt
](
int
i
)
{
int64_t
ibegin
=
std
::
min
(
end
,
begin
+
step
*
i
);
int64_t
iend
=
std
::
min
(
end
,
begin
+
step
*
(
i
+
1
));
int
rv
=
(
*
lambda
)(
ibegin
,
iend
,
env
);
if
(
rv
!=
0
)
{
std
::
ostringstream
os
;
os
<<
"Thread "
<<
i
<<
" error:"
<<
TVMGetLastError
();
rt
->
par_errors
[
i
]
=
os
.
str
();
}
};
for
(
int
i
=
0
;
i
<
nthread
;
++
i
)
{
rt
->
par_threads
[
i
]
=
std
::
thread
(
fexec
,
i
);
}
int
ret
=
0
;
for
(
int
i
=
0
;
i
<
nthread
;
++
i
)
{
rt
->
par_threads
[
i
].
join
();
if
(
rt
->
par_errors
[
i
].
length
()
!=
0
)
ret
=
-
1
;
}
if
(
ret
==
0
)
return
ret
;
std
::
ostringstream
os
;
for
(
int
i
=
0
;
i
<
nthread
;
++
i
)
{
if
(
rt
->
par_errors
[
i
].
length
()
!=
0
)
{
os
<<
rt
->
par_errors
[
i
]
<<
'\n'
;
}
}
rt
->
last_error
=
os
.
str
();
return
-
1
;
}
int
TVMFuncFree
(
TVMFunctionHandle
func
)
{
...
...
src/schedule/schedule_lang.cc
View file @
f6c043eb
...
...
@@ -69,6 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
switch
(
op
->
iter_type
)
{
case
kUnrolled
:
p
->
stream
<<
"unroll"
;
break
;
case
kVectorized
:
p
->
stream
<<
"vectorize"
;
break
;
case
kParallel
:
p
->
stream
<<
"parallel"
;
break
;
}
});
...
...
@@ -246,6 +247,11 @@ Stage& Stage::unroll(IterVar var) { // NOLINT(*)
return
*
this
;
}
Stage
&
Stage
::
parallel
(
IterVar
var
)
{
// NOLINT(*)
SetAttr
(
operator
->
(),
var
,
IterVarAttr
(
kParallel
));
return
*
this
;
}
Schedule
::
Schedule
(
Array
<
Operation
>
ops
)
{
auto
n
=
std
::
make_shared
<
ScheduleNode
>
();
n
->
outputs
=
ops
;
...
...
src/schedule/schedule_ops.cc
View file @
f6c043eb
...
...
@@ -189,6 +189,7 @@ MakeLoopNest(const Stage& sch,
if
(
sch
->
iter_var_attrs
.
count
(
iv
))
{
switch
(
sch
->
iter_var_attrs
[
iv
]
->
iter_type
)
{
case
kUnrolled
:
for_type
=
ForType
::
Unrolled
;
break
;
case
kParallel
:
for_type
=
ForType
::
Parallel
;
break
;
case
kVectorized
:
for_type
=
ForType
::
Vectorized
;
break
;
}
}
...
...
tests/python/unittest/test_codegen_llvm.py
0 → 100644
View file @
f6c043eb
import
tvm
import
numpy
as
np
def
test_llvm_add_pipeline
():
n
=
tvm
.
Var
(
'n'
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
C
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
+
B
(
*
i
),
name
=
'C'
)
s
=
tvm
.
Schedule
(
C
.
op
)
s
[
C
]
.
parallel
(
C
.
op
.
axis
[
0
])
def
check_llvm
():
if
not
tvm
.
codegen
.
enabled
(
"llvm"
):
return
# build and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
n
=
10270
*
2460
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
for
i
in
range
(
1000
):
f
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_llvm
()
if
__name__
==
"__main__"
:
test_llvm_add_pipeline
()
tests/python/unittest/test_codegen_
stack_llvm
.py
→
tests/python/unittest/test_codegen_
vm_basic
.py
View file @
f6c043eb
...
...
@@ -78,40 +78,7 @@ def test_stack_vm_cond():
np
.
testing
.
assert_equal
(
a
.
asnumpy
(),
y
)
run_jit
(
fapi
,
check
)
def
test_llvm_add_pipeline
():
n
=
tvm
.
Var
(
'n'
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
C
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
A
(
*
i
)
+
B
(
*
i
),
name
=
'C'
)
s
=
tvm
.
Schedule
(
C
.
op
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
Bb
=
tvm
.
Buffer
(
B
.
shape
,
B
.
dtype
,
name
=
'B'
)
Cb
=
tvm
.
Buffer
(
C
.
shape
,
C
.
dtype
,
name
=
'C'
)
stmt
=
tvm
.
ir_pass
.
StorageFlatten
(
stmt
,
{
A
:
Ab
,
B
:
Bb
,
C
:
Cb
})
stmt
=
tvm
.
ir_pass
.
Simplify
(
stmt
)
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"myadd"
,
[
Ab
,
Bb
,
Cb
],
0
)
def
check_llvm
():
if
not
tvm
.
codegen
.
enabled
(
"llvm"
):
return
# build and invoke the kernel.
f
=
tvm
.
codegen
.
build
(
fapi
,
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
n
=
1027
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
Ab
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
Bb
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
Cb
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_llvm
()
if
__name__
==
"__main__"
:
test_stack_vm_basic
()
test_stack_vm_cond
()
test_stack_vm_loop
()
test_llvm_add_pipeline
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment