wenyuanbo / tic · Commits · 51c40b4f

Commit 51c40b4f, authored May 10, 2018 by Tianqi Chen, committed by GitHub on May 10, 2018.
[CODEGEN] Enable cross compile of AMDGPU without rocm, update rpc (#1154)
Parent: 11c7b6cf
Showing 13 changed files with 124 additions and 74 deletions:
python/tvm/contrib/rpc/proxy.py               +1   -1
python/tvm/contrib/rpc/server.py              +26  -10
python/tvm/exec/rpc_server.py                 +2   -20
src/codegen/codegen_source_base.h             +7   -4
src/codegen/llvm/codegen_amdgpu.cc            +30  -8
src/codegen/source_module.cc                  +34  -9
src/runtime/module.cc                         +2   -2
topi/tests/python/test_topi_l2norm.py         +2   -2
topi/tests/python/test_topi_lrn.py            +2   -2
topi/tests/python_cpp/test_topi_dense.py      +2   -2
topi/tests/python_cpp/test_topi_pooling.py    +2   -2
topi/tests/python_cpp/test_topi_reduce.py     +2   -2
topi/tests/python_cpp/test_topi_transform.py  +12  -10
python/tvm/contrib/rpc/proxy.py  (+1, -1)

@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
     def _connect(key):
         conn = yield websocket.websocket_connect(url)
         on_message = create_on_message(conn)
-        temp = _server_env()
+        temp = _server_env(None)
         # Start connection
         conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True)
         key = "server:" + key
python/tvm/contrib/rpc/server.py  (+26, -10)

@@ -11,6 +11,7 @@ Server is TCP based with the following protocol:
 from __future__ import absolute_import

 import os
+import ctypes
 import socket
 import select
 import struct

@@ -21,12 +22,13 @@ import time
 from ..._ffi.function import register_func
 from ..._ffi.base import py_str
+from ..._ffi.libinfo import find_lib_path
 from ...module import load as _load_module
 from .. import util
 from . import base
 from .base import TrackerCode

-def _server_env():
+def _server_env(load_library):
     """Server environment function return temp dir"""
     temp = util.tempdir()
     # pylint: disable=unused-variable

@@ -41,13 +43,21 @@ def _server_env():
         m = _load_module(path)
         logging.info("load_module %s", path)
         return m
+
+    libs = []
+    load_library = load_library.split(":") if load_library else []
+    for file_name in load_library:
+        file_name = find_lib_path(file_name)[0]
+        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
+        logging.info("Load additional library %s", file_name)
+    temp.libs = libs
     return temp

-def _serve_loop(sock, addr):
+def _serve_loop(sock, addr, load_library):
     """Server loop"""
     sockfd = sock.fileno()
-    temp = _server_env()
+    temp = _server_env(load_library)
     base._ServerLoop(sockfd)
     temp.remove()
     logging.info("Finish serving %s", addr)

@@ -62,7 +72,7 @@ def _parse_server_opt(opts):
     return ret

-def _listen_loop(sock, port, rpc_key, tracker_addr):
+def _listen_loop(sock, port, rpc_key, tracker_addr, load_library):
     """Listening loop of the server master."""
     def _accept_conn(listen_sock, tracker_conn, ping_period=2):
         """Accept connection from the other places.

@@ -162,7 +172,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
         # step 3: serving
         logging.info("RPCServer: connection from %s", addr)
-        server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
+        server_proc = multiprocessing.Process(target=_serve_loop,
+                                              args=(conn, addr, load_library))
         server_proc.daemon = True
         server_proc.start()
         # close from our side.

@@ -174,7 +184,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
             server_proc.terminate()

-def _connect_proxy_loop(addr, key):
+def _connect_proxy_loop(addr, key, load_library):
     key = "server:" + key
     retry_count = 0
     max_retry = 5

@@ -198,7 +208,7 @@ def _connect_proxy_loop(addr, key):
             opts = _parse_server_opt(remote_key.split()[1:])
             logging.info("RPCProxy connected to %s", str(addr))
             process = multiprocessing.Process(
-                target=_serve_loop, args=(sock, addr))
+                target=_serve_loop, args=(sock, addr, load_library))
             process.daemon = True
             process.start()
             sock.close()

@@ -256,6 +266,9 @@ class Server(object):
     key : str, optional
         The key used to identify the server in Proxy connection.

+    load_library : str, optional
+        List of additional libraries to be loaded during execution.
+
     """
     def __init__(self,
                  host,

@@ -264,7 +277,8 @@ class Server(object):
                  is_proxy=False,
                  use_popen=False,
                  tracker_addr=None,
-                 key=""):
+                 key="",
+                 load_library=None):
         try:
             if base._ServerLoop is None:
                 raise RuntimeError("Please compile with USE_RPC=1")

@@ -283,6 +297,8 @@ class Server(object):
                 assert key
                 cmd += ["--tracker=%s:%d" % tracker_addr,
                         "--key=%s" % key]
+            if load_library:
+                cmd += ["--load-library", load_library]
             self.proc = multiprocessing.Process(
                 target=subprocess.check_call, args=(cmd,))
             self.proc.daemon = True

@@ -308,12 +324,12 @@ class Server(object):
             self.sock = sock
             self.proc = multiprocessing.Process(
-                target=_listen_loop, args=(self.sock, self.port, key, tracker_addr))
+                target=_listen_loop, args=(
+                    self.sock, self.port, key, tracker_addr, load_library))
             self.proc.daemon = True
             self.proc.start()
         else:
             self.proc = multiprocessing.Process(
-                target=_connect_proxy_loop, args=((host, port), key))
+                target=_connect_proxy_loop, args=((host, port), key, load_library))
             self.proc.daemon = True
             self.proc.start()
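Taken together, the server.py changes thread a single load_library string from the public Server constructor down to every worker process, which loads each library with RTLD_GLOBAL before serving. A minimal sketch of the new keyword argument, assuming a hypothetical libmyops.so that find_lib_path can resolve:

    from tvm.contrib import rpc

    # Colon-separated list, mirroring the split(":") in _server_env;
    # each entry is resolved with find_lib_path and opened via ctypes.CDLL.
    server = rpc.Server("0.0.0.0", port=9090, load_library="libmyops.so")
    server.proc.join()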
python/tvm/exec/rpc_server.py  (+2, -20)

 """Start an RPC server"""
 from __future__ import absolute_import

 import logging
 import argparse
-import os
-import ctypes
 from ..contrib import rpc
-from .._ffi.libinfo import find_lib_path

 def main():
     """Main function"""

@@ -19,26 +15,12 @@ def main():
                         help='The end search port of the RPC')
     parser.add_argument('--key', type=str, default="",
                         help="RPC key used to identify the connection type.")
-    parser.add_argument('--with-executor', type=bool, default=False,
-                        help="Whether to load executor runtime")
     parser.add_argument('--load-library', type=str, default="",
                         help="Additional library to load")
     parser.add_argument('--tracker', type=str, default="",
                         help="Report to RPC tracker")
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
-    load_library = [lib for lib in args.load_library.split(":") if len(lib) != 0]
-    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
-    libs = []
-    if args.with_executor:
-        load_library += ["libtvm_graph_exec.so"]
-    for file_name in load_library:
-        file_name = find_lib_path(file_name, apps_path)[0]
-        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
-        logging.info("Load additional library %s", file_name)
     if args.tracker:
         url, port = args.tracker.split(":")
         port = int(port)

@@ -53,8 +35,8 @@ def main():
                          args.port,
                          args.port_end,
                          key=args.key,
-                         tracker_addr=tracker_addr)
-    server.libs += libs
+                         tracker_addr=tracker_addr,
+                         load_library=args.load_library)
     server.proc.join()

 if __name__ == "__main__":
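Since the executable now just forwards the flag, starting a server with preloaded libraries reduces to one command. An assumed invocation (the library name is hypothetical, and --port is defined elsewhere in this argument parser):

    python -m tvm.exec.rpc_server --port 9090 --load-library libmyops.so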
src/codegen/codegen_source_base.h  (+7, -4)

@@ -10,6 +10,7 @@
 #include <tvm/codegen.h>
 #include <string>
 #include <vector>
+#include <functional>
 #include <unordered_map>
 #include "../runtime/meta_data.h"

@@ -111,17 +112,19 @@ class CodeGenSourceBase {
 runtime::Module SourceModuleCreate(std::string code, std::string fmt);

 /*!
- * \brief Create a source module for viewing and limited saving
- * \param code The code to be viewed.
+ * \brief Create a source module for viewing and limited saving for device.
+ * \param data The code data to be viewed.
  * \param fmt The format of the code.
  * \param fmap The function information map of each function.
  * \param type_key The type_key of the runtime module of this source code
+ * \param fget_source a closure to replace default get source behavior.
  */
 runtime::Module DeviceSourceModuleCreate(
-    std::string code,
+    std::string data,
     std::string fmt,
     std::unordered_map<std::string, runtime::FunctionInfo> fmap,
-    std::string type_key);
+    std::string type_key,
+    std::function<std::string(const std::string&)> fget_source = nullptr);

 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
src/codegen/llvm/codegen_amdgpu.cc  (+30, -8)

@@ -4,15 +4,18 @@
  * \brief AMDGPU code generator.
  */
 #ifdef TVM_LLVM_VERSION
-#if TVM_ROCM_RUNTIME

 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/registry.h>
 #include "./codegen_llvm.h"
 #include "../build_common.h"
+#include "../codegen_source_base.h"
 #include "../../pass/ir_util.h"
+#if TVM_ROCM_RUNTIME
 #include "../../runtime/rocm/rocm_module.h"
+#endif  // TVM_ROCM_RUNTIME

 namespace tvm {
 namespace codegen {

@@ -131,19 +134,27 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };

-inline int DetectROCMComputeVersion() {
+inline int DetectROCMComputeVersion(const std::string& target) {
+  size_t pos = target.find("=gfx");
+  if (pos != std::string::npos) {
+    int value;
+    std::stringstream is(target.substr(pos + 4));
+    if (is >> value) return value;
+  }
   TVMContext tvm_ctx;
   tvm_ctx.device_type = kDLROCM;
   tvm_ctx.device_id = 0;
-  TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
-      tvm_ctx, tvm::runtime::kExist, &val);
-  if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
-        tvm_ctx, tvm::runtime::kComputeVersion, &val);
-    return val.operator int();
-  } else {
-    return 803;
-  }
+  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
+  if (api != nullptr) {
+    TVMRetValue val;
+    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
+    if (val.operator int() == 1) {
+      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+          tvm_ctx, tvm::runtime::kComputeVersion, &val);
+      return val.operator int();
+    }
+  }
+  LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version, assume gfx803";
+  return 803;
 }

@@ -151,7 +162,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
       target.substr(0, 4) == "rocm");
   std::ostringstream config;
   config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
-         << DetectROCMComputeVersion()
+         << DetectROCMComputeVersion(target)
          << target.substr(4, target.length() - 4);
   llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());

@@ -216,7 +227,19 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   std::string hsaco = (*f)(arr);
   std::string ll(data_ll.begin(), data_ll.end());
+#if TVM_ROCM_RUNTIME
   return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
+#else
+  LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
+  auto fget_source = [ll, assembly](const std::string& format) {
+    if (format.length() == 0) return assembly;
+    if (format == "ll" || format == "llvm") return ll;
+    if (format == "asm") return assembly;
+    return std::string("");
+  };
+  return DeviceSourceModuleCreate(
+      hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source);
+#endif  // TVM_ROCM_RUNTIME
 }

 TVM_REGISTER_API("codegen.build_rocm")

@@ -226,5 +249,4 @@ TVM_REGISTER_API("codegen.build_rocm")
 }  // namespace codegen
 }  // namespace tvm
-#endif  // TVM_ROCM_RUNTIME
 #endif  // TVM_LLVM_VERSION
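The upshot is that a rocm target can now be built on a machine that has LLVM's AMDGPU backend but no ROCm installation, provided the compute version is pinned via -mcpu so DetectROCMComputeVersion never has to query a live device. A sketch of that workflow in the Python API of this era (the shapes and the gfx900 target string are illustrative, not prescribed by the commit):

    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
    # "-mcpu=gfx900" supplies the compute version; without
    # TVM_ROCM_RUNTIME the device part of the result is a source
    # module rather than a loaded ROCm module.
    f = tvm.build(s, [A, B], "rocm -mcpu=gfx900")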
src/codegen/source_module.cc  (+34, -9)

@@ -54,13 +54,34 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
 }

 // supports limited save without cross compile
-class DeviceSourceModuleNode final : public SourceModuleNode {
+class DeviceSourceModuleNode final : public runtime::ModuleNode {
  public:
-  DeviceSourceModuleNode(std::string code,
-                         std::string fmt,
-                         std::unordered_map<std::string, FunctionInfo> fmap,
-                         std::string type_key)
-      : SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
+  DeviceSourceModuleNode(std::string data,
+                         std::string fmt,
+                         std::unordered_map<std::string, FunctionInfo> fmap,
+                         std::string type_key,
+                         std::function<std::string(const std::string&)> fget_source)
+      : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key),
+        fget_source_(fget_source) {}
+
+  PackedFunc GetFunction(
+      const std::string& name,
+      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    LOG(FATAL) << "Source module cannot execute, to get executable module"
+               << " build TVM with \'" << fmt_ << "\' runtime support";
+    return PackedFunc();
+  }
+
+  std::string GetSource(const std::string& format) final {
+    if (fget_source_ != nullptr) {
+      return fget_source_(format);
+    } else {
+      return data_;
+    }
+  }

   const char* type_key() const {
     return type_key_.c_str();

@@ -73,27 +94,31 @@ class DeviceSourceModuleNode final : public SourceModuleNode {
         << "Can only save to format=" << fmt_;
     std::string meta_file = GetMetaFilePath(file_name);
     SaveMetaDataToFile(meta_file, fmap_);
-    SaveBinaryToFile(file_name, code_);
+    SaveBinaryToFile(file_name, data_);
   }

   void SaveToBinary(dmlc::Stream* stream) final {
     stream->Write(fmt_);
     stream->Write(fmap_);
-    stream->Write(code_);
+    stream->Write(data_);
   }

  private:
+  std::string data_;
+  std::string fmt_;
   std::unordered_map<std::string, FunctionInfo> fmap_;
   std::string type_key_;
+  std::function<std::string(const std::string&)> fget_source_;
 };

 runtime::Module DeviceSourceModuleCreate(
-    std::string code,
+    std::string data,
     std::string fmt,
     std::unordered_map<std::string, FunctionInfo> fmap,
-    std::string type_key) {
+    std::string type_key,
+    std::function<std::string(const std::string&)> fget_source) {
   std::shared_ptr<DeviceSourceModuleNode> n =
-      std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
+      std::make_shared<DeviceSourceModuleNode>(
+          data, fmt, fmap, type_key, fget_source);
   return runtime::Module(n);
 }
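On the Python side, a DeviceSourceModuleNode shows up as the imported device module, and get_source is routed through the stored closure when one was supplied. A small sketch, reusing the cross-compiled f from the AMDGPU example above:

    dev_mod = f.imported_modules[0]
    print(dev_mod.type_key)           # "hsaco"
    print(dev_mod.get_source("asm"))  # assembly, via fget_source
    print(dev_mod.get_source("ll"))   # LLVM IR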
src/runtime/module.cc  (+2, -2)

@@ -121,9 +121,9 @@ bool RuntimeEnabled(const std::string& target) {
   } else if (target == "vpi" || target == "verilog") {
     f_name = "device_api.vpi";
   } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
-    f_name = "codegen.build_nvptx";
+    f_name = "device_api.gpu";
   } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
-    f_name = "codegen.build_rocm";
+    f_name = "device_api.rocm";
   } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
     const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
     if (pf == nullptr) return false;
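This distinction matters for the cross-compile story: RuntimeEnabled backs tvm.module.enabled, and keying it to the device API rather than the codegen hook means it now answers "can I run on this target" instead of "can I generate code for it". A brief illustration, assuming a build with the AMDGPU codegen but no ROCm runtime:

    import tvm

    # False on such a build, even though tvm.build(..., "rocm -mcpu=gfx900")
    # can still succeed and produce a source module for later use.
    print(tvm.module.enabled("rocm"))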
topi/tests/python/test_topi_l2norm.py  (+2, -2)

@@ -41,13 +41,13 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
     b_np = l2norm_instance_python(a_np, eps, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_l2norm(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
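This and the remaining topi test updates apply the same two-line change, for the same reason: with cross compilation enabled, codegen support no longer implies a runnable device, so the skip check must probe for an actual context. The pattern in isolation:

    # Old: skipped only when the target's codegen was unavailable.
    if not tvm.module.enabled(device):
        return
    # New: skip when no physical device exists; a cross-compile-only
    # build enables codegen for "rocm" without any device present.
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        return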
topi/tests/python/test_topi_lrn.py  (+2, -2)

@@ -70,13 +70,13 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
     b_np = lrn_python(a_np, size, axis, bias, alpha, beta)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.generic.schedule_lrn(B)
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
topi/tests/python_cpp/test_topi_dense.py  (+2, -2)

@@ -29,7 +29,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
     a_np, b_np, c_np, d_np = get_ref_data()

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -40,7 +41,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
             s = topi.cpp.rocm.schedule_dense(target, [D])
         else:
             s = topi.cpp.cuda.schedule_dense(target, [D])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(c_np, ctx)
topi/tests/python_cpp/test_topi_pooling.py  (+2, -2)

@@ -48,7 +48,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
     b_np = np.maximum(b_np, 0.0)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -57,7 +58,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
             s = topi.cpp.generic.default_schedule(target, [B], False)
         else:
             s = topi.cpp.cuda.schedule_pool(target, [B])
-        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device)
topi/tests/python_cpp/test_topi_reduce.py  (+2, -2)

@@ -46,7 +46,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         raise NotImplementedError

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -56,7 +57,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         else:
             s = topi.cpp.cuda.schedule_reduce(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="sum")
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
topi/tests/python_cpp/test_topi_transform.py  (+12, -10)

@@ -7,7 +7,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.cpp.expand_dims(A, axis, num_newaxis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -16,7 +17,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="expand_dims")
         data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
         out_npy = data_npy.reshape(out_shape)

@@ -33,7 +33,8 @@ def verify_tranpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.cpp.transpose(A, axes)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -59,7 +60,8 @@ def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.cpp.reshape(A, dst_shape)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -68,7 +70,6 @@ def verify_reshape(src_shape, dst_shape):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="reshape")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.reshape(data_npy, newshape=dst_shape)

@@ -85,7 +86,8 @@ def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
     B = topi.cpp.squeeze(A, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -94,7 +96,6 @@ def verify_squeeze(src_shape, axis):
             s = topi.cpp.generic.schedule_injective(target, [B])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [B])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, [A, B], device, name="squeeze")
         data_npy = np.random.normal(size=src_shape).astype(A.dtype)
         out_npy = np.squeeze(data_npy, axis=axis)

@@ -116,7 +117,8 @@ def verify_concatenate(shapes, axis):
         tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
     out_tensor = topi.cpp.concatenate(tensor_l, axis)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)

@@ -125,7 +127,6 @@ def verify_concatenate(shapes, axis):
             s = topi.cpp.generic.schedule_injective(target, [out_tensor])
         else:
             s = topi.cpp.cuda.schedule_injective(target, [out_tensor])
-        ctx = tvm.context(device, 0)
         foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
         data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
         out_npy = np.concatenate(data_npys, axis=axis)

@@ -143,7 +144,8 @@ def verify_split(src_shape, indices_or_sections, axis):
     tensor_l = topi.cpp.split(A, indices_or_sections, axis)
     tensor_l = list(tensor_l)

     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)