Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
08505e34
Commit
08505e34
authored
Feb 07, 2017
by
Tianqi Chen
Committed by
GitHub
Feb 07, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ADDON] Allow piggy back nvcc compiler and code (#35)
parent
88377988
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
171 additions
and
24 deletions
+171
-24
include/tvm/runtime/c_runtime_api.h
+12
-2
include/tvm/runtime/packed_func.h
+14
-3
python/tvm/_ctypes/_function.py
+11
-2
python/tvm/_ctypes/_types.py
+25
-5
python/tvm/addon/__init__.py
+1
-0
python/tvm/addon/nvcc_compiler.py
+55
-0
src/arithmetic/canonical.cc
+9
-4
src/codegen/codegen_cuda.cc
+13
-1
src/runtime/packed_func_registry.cc
+6
-0
tests/python/integration/test_gemm.py
+16
-6
tests/python/unittest/test_runtime_packed_func.py
+9
-1
No files found.
include/tvm/runtime/c_runtime_api.h
View file @
08505e34
...
@@ -51,8 +51,9 @@ typedef enum {
...
@@ -51,8 +51,9 @@ typedef enum {
kArrayHandle
=
5U
,
kArrayHandle
=
5U
,
kTVMType
=
6U
,
kTVMType
=
6U
,
kNodeHandle
=
7U
,
kNodeHandle
=
7U
,
kStr
=
8U
,
kFuncHandle
=
8U
,
kFuncHandle
=
9U
kStr
=
9U
,
kBytes
=
10U
}
TVMTypeCode
;
}
TVMTypeCode
;
/*!
/*!
...
@@ -87,6 +88,15 @@ typedef union {
...
@@ -87,6 +88,15 @@ typedef union {
}
TVMValue
;
}
TVMValue
;
/*!
/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
*/
typedef
struct
{
const
char
*
data
;
size_t
size
;
}
TVMByteArray
;
/*!
* \brief The device type
* \brief The device type
*/
*/
typedef
enum
{
typedef
enum
{
...
...
include/tvm/runtime/packed_func.h
View file @
08505e34
...
@@ -112,6 +112,12 @@ class PackedFunc {
...
@@ -112,6 +112,12 @@ class PackedFunc {
*/
*/
static
const
PackedFunc
&
GetGlobal
(
const
std
::
string
&
name
);
static
const
PackedFunc
&
GetGlobal
(
const
std
::
string
&
name
);
/*!
/*!
* \brief Whether the global function exist
* \param name The name of the function.
* \return Whetehr the global function exist.
*/
static
bool
GlobalExist
(
const
std
::
string
&
name
);
/*!
* \brief Get the names of currently registered global function.
* \brief Get the names of currently registered global function.
*/
*/
static
std
::
vector
<
std
::
string
>
ListGlobalNames
();
static
std
::
vector
<
std
::
string
>
ListGlobalNames
();
...
@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
...
@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
operator
std
::
string
()
const
{
operator
std
::
string
()
const
{
if
(
type_code_
==
kTVMType
)
{
if
(
type_code_
==
kTVMType
)
{
return
TVMType2String
(
operator
TVMType
());
return
TVMType2String
(
operator
TVMType
());
}
else
if
(
type_code_
==
kBytes
)
{
TVMByteArray
*
arr
=
static_cast
<
TVMByteArray
*>
(
value_
.
v_handle
);
return
std
::
string
(
arr
->
data
,
arr
->
size
);
}
else
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kStr
);
return
std
::
string
(
value_
.
v_str
);
}
}
TVM_CHECK_TYPE_CODE
(
type_code_
,
kStr
);
return
std
::
string
(
value_
.
v_str
);
}
}
operator
TVMType
()
const
{
operator
TVMType
()
const
{
if
(
type_code_
==
kStr
)
{
if
(
type_code_
==
kStr
)
{
...
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
template
<
typename
T
>
template
<
typename
T
>
void
Assign
(
const
T
&
other
)
{
void
Assign
(
const
T
&
other
)
{
switch
(
other
.
type_code
())
{
switch
(
other
.
type_code
())
{
case
kStr
:
{
case
kStr
:
case
kBytes
:
{
SwitchToClass
<
std
::
string
>
(
kStr
,
other
);
SwitchToClass
<
std
::
string
>
(
kStr
,
other
);
break
;
break
;
}
}
...
...
python/tvm/_ctypes/_function.py
View file @
08505e34
# coding: utf-8
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access
, too-many-branches
"""Symbolic configuration API."""
"""Symbolic configuration API."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
...
@@ -9,7 +9,7 @@ from numbers import Number, Integral
...
@@ -9,7 +9,7 @@ from numbers import Number, Integral
from
.._base
import
_LIB
,
check_call
from
.._base
import
_LIB
,
check_call
from
.._base
import
c_str
,
py_str
,
string_types
from
.._base
import
c_str
,
py_str
,
string_types
from
._types
import
TVMValue
,
TypeCode
,
TVMType
from
._types
import
TVMValue
,
TypeCode
,
TVMType
,
TVMByteArray
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
...
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
...
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
elif
isinstance
(
arg
,
TVMType
):
elif
isinstance
(
arg
,
TVMType
):
values
[
i
]
.
v_str
=
c_str
(
str
(
arg
))
values
[
i
]
.
v_str
=
c_str
(
str
(
arg
))
type_codes
[
i
]
=
TypeCode
.
STR
type_codes
[
i
]
=
TypeCode
.
STR
elif
isinstance
(
arg
,
bytearray
):
arr
=
TVMByteArray
()
arr
.
data
=
ctypes
.
cast
(
(
ctypes
.
c_byte
*
len
(
arg
))
.
from_buffer
(
arg
),
ctypes
.
POINTER
(
ctypes
.
c_byte
))
arr
.
size
=
len
(
arg
)
values
[
i
]
.
v_handle
=
ctypes
.
c_void_p
(
ctypes
.
addressof
(
arr
))
temp_args
.
append
(
arr
)
type_codes
[
i
]
=
TypeCode
.
BYTES
elif
isinstance
(
arg
,
string_types
):
elif
isinstance
(
arg
,
string_types
):
values
[
i
]
.
v_str
=
c_str
(
arg
)
values
[
i
]
.
v_str
=
c_str
(
arg
)
type_codes
[
i
]
=
TypeCode
.
STR
type_codes
[
i
]
=
TypeCode
.
STR
...
...
python/tvm/_ctypes/_types.py
View file @
08505e34
...
@@ -18,8 +18,9 @@ class TypeCode(object):
...
@@ -18,8 +18,9 @@ class TypeCode(object):
ARRAY_HANDLE
=
5
ARRAY_HANDLE
=
5
TVM_TYPE
=
6
TVM_TYPE
=
6
NODE_HANDLE
=
7
NODE_HANDLE
=
7
STR
=
8
FUNC_HANDLE
=
8
FUNC_HANDLE
=
9
STR
=
9
BYTES
=
10
def
_api_type
(
code
):
def
_api_type
(
code
):
"""create a type accepted by API"""
"""create a type accepted by API"""
...
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
...
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
(
"v_handle"
,
ctypes
.
c_void_p
),
(
"v_handle"
,
ctypes
.
c_void_p
),
(
"v_str"
,
ctypes
.
c_char_p
)]
(
"v_str"
,
ctypes
.
c_char_p
)]
class
TVMByteArray
(
ctypes
.
Structure
):
"""TVM datatype structure"""
_fields_
=
[(
"data"
,
ctypes
.
POINTER
(
ctypes
.
c_byte
)),
(
"size"
,
ctypes
.
c_size_t
)]
TVMPackedCFunc
=
ctypes
.
CFUNCTYPE
(
TVMPackedCFunc
=
ctypes
.
CFUNCTYPE
(
None
,
None
,
...
@@ -110,20 +116,34 @@ def _return_handle(x):
...
@@ -110,20 +116,34 @@ def _return_handle(x):
handle
=
ctypes
.
c_void_p
(
handle
)
handle
=
ctypes
.
c_void_p
(
handle
)
return
handle
return
handle
def
_return_bytes
(
x
):
"""return handle"""
handle
=
x
.
v_handle
if
not
isinstance
(
handle
,
ctypes
.
c_void_p
):
handle
=
ctypes
.
c_void_p
(
handle
)
arr
=
ctypes
.
cast
(
handle
,
ctypes
.
POINTER
(
TVMByteArray
))[
0
]
size
=
arr
.
size
res
=
bytearray
(
size
)
rptr
=
(
ctypes
.
c_byte
*
size
)
.
from_buffer
(
res
)
if
not
ctypes
.
memmove
(
rptr
,
arr
.
data
,
size
):
raise
RuntimeError
(
'memmove failed'
)
return
res
RETURN_SWITCH
=
{
RETURN_SWITCH
=
{
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
)
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
BYTES
:
_return_bytes
}
}
C_TO_PY_ARG_SWITCH
=
{
C_TO_PY_ARG_SWITCH
=
{
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
)
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
BYTES
:
_return_bytes
}
}
python/tvm/addon/__init__.py
0 → 100644
View file @
08505e34
"""Addon utilities to python"""
python/tvm/addon/nvcc_compiler.py
0 → 100644
View file @
08505e34
"""Util to compile with NVCC"""
import
os
import
sys
import
tempfile
import
subprocess
def
compile_source
(
code
,
target
=
"cubin"
):
"""Compile cuda code with NVCC from env.
Parameters
----------
code : str
The cuda code.
target: str
The target format
Return
------
cubin : bytearray
The bytearray of the cubin
"""
temp_dir
=
tempfile
.
mkdtemp
()
if
target
not
in
[
"cubin"
,
"ptx"
,
"fatbin"
]:
raise
ValueError
(
"target must be in cubin, ptx, fatbin"
)
path_code
=
os
.
path
.
join
(
temp_dir
,
"my_kernel.cu"
)
path_target
=
os
.
path
.
join
(
temp_dir
,
"my_kernel.
%
s"
%
target
)
with
open
(
path_code
,
"w"
)
as
out_file
:
out_file
.
write
(
code
)
cmd
=
[
"nvcc"
]
cmd
+=
[
"--
%
s"
%
target
,
"-O3"
]
cmd
+=
[
"-o"
,
path_target
]
cmd
+=
[
path_code
]
args
=
' '
.
join
(
cmd
)
proc
=
subprocess
.
Popen
(
args
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
STDOUT
)
(
out
,
_
)
=
proc
.
communicate
()
if
proc
.
returncode
!=
0
:
sys
.
stderr
.
write
(
"Compilation error:
\n
"
)
sys
.
stderr
.
write
(
out
)
sys
.
stderr
.
flush
()
cubin
=
None
else
:
cubin
=
bytearray
(
open
(
path_target
,
"rb"
)
.
read
())
os
.
remove
(
path_code
)
if
os
.
path
.
exists
(
path_target
):
os
.
remove
(
path_target
)
os
.
rmdir
(
temp_dir
)
return
cubin
src/arithmetic/canonical.cc
View file @
08505e34
...
@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
...
@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
}
}
// functions
// functions
Stmt
Mutate
(
Stmt
stmt
)
final
{
Stmt
Mutate
(
Stmt
stmt
)
final
{
return
IRMutator
::
Mutate
(
stmt
);
stmt
=
IRMutator
::
Mutate
(
stmt
);
return
stmt
;
}
}
Expr
MutateExpr_
(
Expr
expr
)
{
Expr
MutateExpr_
(
Expr
expr
)
{
static
const
FMutateExpr
&
f
=
Internal
::
vtable_expr
();
static
const
FMutateExpr
&
f
=
Internal
::
vtable_expr
();
...
@@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
...
@@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
stack_
.
pop_back
();
stack_
.
pop_back
();
CHECK
(
expr
.
defined
());
return
expr
;
return
expr
;
}
}
// call produce to get a cache entry.
// call produce to get a cache entry.
...
@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
...
@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
// subroutine to do produce
// subroutine to do produce
Expr
SumAdd
(
CacheEntry
a
,
CacheEntry
b
,
int
bscale
)
{
Expr
SumAdd
(
CacheEntry
a
,
CacheEntry
b
,
int
bscale
)
{
ret_entry_
.
sum
=
SumAdd_
(
a
.
AsSum
(),
b
.
AsSum
(),
bscale
);
ret_entry_
.
sum
=
SumAdd_
(
a
.
AsSum
(),
b
.
AsSum
(),
bscale
);
CHECK_NE
(
stack_
.
size
(),
0U
);
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
max_level
=
stack_
.
back
().
max_level
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
ret_entry_
.
has_side_effect
=
stack_
.
back
().
has_side_effect
;
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
auto
it
=
cache_sum_
.
find
(
ret_entry_
.
sum
);
...
@@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
...
@@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
ret_entry_
.
value
=
Sum2Expr
(
ret_entry_
.
sum
,
a
.
value
.
type
());
ret_entry_
.
value
=
Sum2Expr
(
ret_entry_
.
sum
,
a
.
value
.
type
());
cache_sum_
[
ret_entry_
.
sum
]
=
ret_entry_
;
cache_sum_
[
ret_entry_
.
sum
]
=
ret_entry_
;
}
}
ret_entry_
.
value
=
Sum2Expr
(
ret_entry_
.
sum
,
a
.
value
.
type
());
cache_sum_
[
ret_entry_
.
sum
]
=
ret_entry_
;
return
ret_entry_
.
value
;
return
ret_entry_
.
value
;
}
}
// convert sum to expr
// convert sum to expr
...
@@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
...
@@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
}
}
}
}
}
}
return
vsum
;
if
(
vsum
.
defined
())
{
return
vsum
;
}
else
{
return
make_zero
(
t
);
}
}
}
};
};
...
...
src/codegen/codegen_cuda.cc
View file @
08505e34
...
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
...
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
os
<<
CodeGenCUDA
().
Compile
(
f
,
output_ssa
);
os
<<
CodeGenCUDA
().
Compile
(
f
,
output_ssa
);
os
<<
'\n'
;
os
<<
'\n'
;
}
}
std
::
string
ptx
=
runtime
::
NVRTCCompile
(
os
.
str
());
std
::
string
code
=
os
.
str
();
if
(
PackedFunc
::
GlobalExist
(
"tvm_callback_cuda_postproc"
))
{
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_postproc"
);
code
=
f
(
code
).
operator
std
::
string
();
}
std
::
string
ptx
;
if
(
PackedFunc
::
GlobalExist
(
"tvm_callback_cuda_compile"
))
{
const
auto
&
f
=
PackedFunc
::
GetGlobal
(
"tvm_callback_cuda_compile"
);
ptx
=
f
(
code
).
operator
std
::
string
();
}
else
{
ptx
=
runtime
::
NVRTCCompile
(
os
.
str
());
}
std
::
unordered_map
<
LoweredFunc
,
PackedFunc
>
ret
;
std
::
unordered_map
<
LoweredFunc
,
PackedFunc
>
ret
;
runtime
::
CUDAModule
m
=
runtime
::
CUDAModule
::
Create
(
ptx
);
runtime
::
CUDAModule
m
=
runtime
::
CUDAModule
::
Create
(
ptx
);
...
...
src/runtime/packed_func_registry.cc
View file @
08505e34
...
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
...
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
return
*
(
it
->
second
);
return
*
(
it
->
second
);
}
}
bool
PackedFunc
::
GlobalExist
(
const
std
::
string
&
name
)
{
PackedFuncRegistry
*
r
=
PackedFuncRegistry
::
Global
();
auto
it
=
r
->
fmap
.
find
(
name
);
return
it
!=
r
->
fmap
.
end
();
}
std
::
vector
<
std
::
string
>
PackedFunc
::
ListGlobalNames
()
{
std
::
vector
<
std
::
string
>
PackedFunc
::
ListGlobalNames
()
{
PackedFuncRegistry
*
r
=
PackedFuncRegistry
::
Global
();
PackedFuncRegistry
*
r
=
PackedFuncRegistry
::
Global
();
std
::
vector
<
std
::
string
>
keys
;
std
::
vector
<
std
::
string
>
keys
;
...
...
tests/python/integration/test_gemm.py
View file @
08505e34
import
tvm
import
tvm
from
tvm.addon
import
nvcc_compiler
import
numpy
as
np
import
numpy
as
np
@tvm.register_func
def
tvm_callback_cuda_compile
(
code
):
ptx
=
nvcc_compiler
.
compile_source
(
code
,
target
=
"ptx"
)
print
(
ptx
.
decode
(
"utf-8"
))
return
ptx
@tvm.register_func
def
tvm_callback_cuda_postproc
(
code
):
print
(
code
)
return
code
def
test_gemm
():
def
test_gemm
():
# graph
# graph
nn
=
1024
nn
=
1024
...
@@ -23,7 +35,6 @@ def test_gemm():
...
@@ -23,7 +35,6 @@ def test_gemm():
s
=
tvm
.
Schedule
(
C
.
op
)
s
=
tvm
.
Schedule
(
C
.
op
)
xtile
,
ytile
=
32
,
32
xtile
,
ytile
=
32
,
32
s
[
AA
]
.
set_scope
(
"shared"
)
s
[
AA
]
.
set_scope
(
"shared"
)
#s[CC].set_scope("global")
s
[
BB
]
.
set_scope
(
"shared"
)
s
[
BB
]
.
set_scope
(
"shared"
)
scale
=
8
scale
=
8
...
@@ -60,8 +71,6 @@ def test_gemm():
...
@@ -60,8 +71,6 @@ def test_gemm():
codes
=
[]
codes
=
[]
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
,
record_codes
=
codes
,
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
,
record_codes
=
codes
,
max_auto_unroll_step
=
max_auto_unroll_step
)
max_auto_unroll_step
=
max_auto_unroll_step
)
for
c
in
codes
[
1
:]:
print
(
c
)
if
target
==
"cuda"
:
if
target
==
"cuda"
:
ctx
=
tvm
.
gpu
(
0
)
ctx
=
tvm
.
gpu
(
0
)
else
:
else
:
...
@@ -77,13 +86,14 @@ def test_gemm():
...
@@ -77,13 +86,14 @@ def test_gemm():
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
a
=
tvm
.
nd
.
array
(
a_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
b_np
,
ctx
)
b
=
tvm
.
nd
.
array
(
b_np
,
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
for
i
in
range
(
4
):
f
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a_np
,
b_np
.
T
),
rtol
=
1e-5
)
tvm
.
init_opencl
()
check_device
(
"cuda"
)
check_device
(
"cuda"
)
check_device
(
"opencl"
)
#tvm.init_opencl()
#check_device("opencl")
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_gemm
()
test_gemm
()
tests/python/unittest/test_runtime_packed_func.py
View file @
08505e34
...
@@ -35,9 +35,17 @@ def test_convert():
...
@@ -35,9 +35,17 @@ def test_convert():
assert
isinstance
(
f
,
tvm
.
nd
.
Function
)
assert
isinstance
(
f
,
tvm
.
nd
.
Function
)
f
(
*
targs
)
f
(
*
targs
)
def
test_byte_array
():
s
=
"hello"
a
=
bytearray
(
s
,
encoding
=
"ascii"
)
def
myfunc
(
ss
):
assert
ss
==
a
f
=
tvm
.
convert
(
myfunc
)
f
(
a
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_function
()
test_convert
()
test_convert
()
test_get_global
()
test_get_global
()
test_return_func
()
test_return_func
()
test_byte_array
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment