Commit 8214d6ca, authored Nov 03, 2017 by Tianqi Chen, committed by GitHub on Nov 03, 2017
[DLPack] Upgrade dlpack to 0.2 (#609)
parent a152a9cb
Showing 31 changed files with 122 additions and 121 deletions (+122 -121)
apps/howto_deploy/cpp_deploy.cc  +2 -2
dlpack  +1 -1
include/tvm/packed_func_ext.h  +2 -2
include/tvm/runtime/packed_func.h  +18 -18
jvm/native/src/main/native/jni_helper_func.h  +3 -3
jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc  +2 -2
src/api/api_lang.cc  +2 -2
src/codegen/llvm/codegen_amdgpu.cc  +1 -1
src/codegen/llvm/codegen_llvm.cc  +1 -1
src/codegen/llvm/codegen_nvptx.cc  +1 -1
src/codegen/stack_vm/stack_vm.h  +6 -6
src/codegen/verilog/vpi_device_api.cc  +2 -2
src/contrib/cblas/cblas.cc  +3 -3
src/contrib/cudnn/cudnn_utils.cc  +3 -3
src/contrib/nnpack/convolution.cc  +8 -8
src/contrib/nnpack/fully_connected.cc  +6 -6
src/pass/lower_tvm_builtin.cc  +1 -1
src/pass/make_api.cc  +4 -3
src/runtime/c_runtime_api.cc  +11 -11
src/runtime/cpu_device_api.cc  +1 -1
src/runtime/cuda/cuda_device_api.cc  +4 -4
src/runtime/graph/graph_runtime.cc  +1 -1
src/runtime/metal/metal_common.h  +4 -4
src/runtime/metal/metal_device_api.mm  +4 -4
src/runtime/opencl/opencl_common.h  +3 -3
src/runtime/opencl/opencl_device_api.cc  +3 -3
src/runtime/pack_args.h  +3 -3
src/runtime/rocm/rocm_device_api.cc  +4 -4
src/runtime/rpc/rpc_device_api.cc  +2 -2
src/runtime/rpc/rpc_session.cc  +15 -15
tests/cpp/packed_func_test.cc  +1 -1
apps/howto_deploy/cpp_deploy.cc
@@ -28,10 +28,10 @@ void Verify(tvm::runtime::Module mod, std::string fname) {
   DLTensor* x;
   DLTensor* y;
   int ndim = 1;
-  int dtype_code = kFloat;
+  int dtype_code = kDLFloat;
   int dtype_bits = 32;
   int dtype_lanes = 1;
-  int device_type = kCPU;
+  int device_type = kDLCPU;
   int device_id = 0;
   int64_t shape[1] = {10};
   TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
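For context (not part of the commit), here is roughly how the updated allocation in cpp_deploy.cc is used end to end; a minimal sketch assuming the TVM C runtime API of this era plus the DLPack 0.2 header:

#include <dlpack/dlpack.h>
#include <tvm/runtime/c_runtime_api.h>

// Allocate a 1-D float32 tensor of 10 elements on the host, using the
// DLPack 0.2 enum names this commit switches to (kDLFloat, kDLCPU).
int AllocateExample() {
  DLTensor* x = nullptr;
  int64_t shape[1] = {10};
  int ret = TVMArrayAlloc(shape, /*ndim=*/1,
                          /*dtype_code=*/kDLFloat, /*dtype_bits=*/32,
                          /*dtype_lanes=*/1,
                          /*device_type=*/kDLCPU, /*device_id=*/0, &x);
  if (ret != 0) return ret;  // TVMGetLastError() describes the failure
  // ... fill x->data and invoke the deployed module, as cpp_deploy.cc does ...
  TVMArrayFree(x);
  return 0;
}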
dlpack @ 10892ac9
-Subproject commit 9422e98f3f4dafc6bc3473cf8484543ad376aab6
+Subproject commit 10892ac964f1af7c81aae145cd3fab78bbccd297
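The submodule bump above is what introduces the kDL-prefixed names used in every other file of this commit. As a reference, an abridged sketch of the relevant DLPack 0.2 declarations (reconstructed from dlpack/dlpack.h; the exact numeric values are an assumption, check the pinned header):

// Abridged, assumed from dlpack 0.2's dlpack.h
typedef enum {
  kDLCPU = 1,
  kDLGPU = 2,
  kDLCPUPinned = 3,
  kDLOpenCL = 4,
  kDLMetal = 8,
  kDLVPI = 9,
  kDLROCM = 10,
} DLDeviceType;

typedef enum {
  kDLInt = 0U,
  kDLUInt = 1U,
  kDLFloat = 2U,
} DLDataTypeCode;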
include/tvm/packed_func_ext.h
@@ -105,10 +105,10 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
 inline TVMArgValue::operator Halide::Expr() const {
   if (type_code_ == kNull) return Expr();
-  if (type_code_ == kInt) {
+  if (type_code_ == kDLInt) {
     return Expr(static_cast<int>(value_.v_int64));
   }
-  if (type_code_ == kFloat) {
+  if (type_code_ == kDLFloat) {
     return Expr(static_cast<float>(value_.v_float64));
   }
   TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
include/tvm/runtime/packed_func.h
@@ -217,25 +217,25 @@ class ExtTypeVTable {
 class TVMPODValue_ {
  public:
   operator double() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kFloat);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
     return value_.v_float64;
   }
   operator int64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator uint64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator int() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
     return static_cast<int>(value_.v_int64);
   }
   operator bool() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64 != 0;
   }
   operator void*() const {
@@ -430,7 +430,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(double value) {
-    this->SwitchToPOD(kFloat);
+    this->SwitchToPOD(kDLFloat);
     value_.v_float64 = value;
     return *this;
   }
@@ -445,12 +445,12 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(int64_t value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
   TVMRetValue& operator=(int value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -460,7 +460,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(bool value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -609,9 +609,9 @@ class TVMRetValue : public TVMPODValue_ {
 // implementation details
 inline const char* TypeCode2Str(int type_code) {
   switch (type_code) {
-    case kInt: return "int";
-    case kUInt: return "uint";
-    case kFloat: return "float";
+    case kDLInt: return "int";
+    case kDLUInt: return "uint";
+    case kDLFloat: return "float";
     case kStr: return "str";
     case kBytes: return "bytes";
     case kHandle: return "handle";
@@ -648,11 +648,11 @@ inline TVMType String2TVMType(std::string s) {
   t.bits = 32; t.lanes = 1;
   const char* scan;
   if (s.substr(0, 3) == "int") {
-    t.code = kInt;  scan = s.c_str() + 3;
+    t.code = kDLInt;  scan = s.c_str() + 3;
   } else if (s.substr(0, 4) == "uint") {
-    t.code = kUInt;  scan = s.c_str() + 4;
+    t.code = kDLUInt;  scan = s.c_str() + 4;
   } else if (s.substr(0, 5) == "float") {
-    t.code = kFloat;  scan = s.c_str() + 5;
+    t.code = kDLFloat;  scan = s.c_str() + 5;
   } else if (s.substr(0, 6) == "handle") {
     t.code = kHandle;
     t.bits = 64;  // handle uses 64 bit by default.
@@ -724,17 +724,17 @@ class TVMArgsSetter {
             std::is_integral<T>::value>::type>
   void operator()(size_t i, T value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, uint64_t value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
     CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, double value) const {
     values_[i].v_float64 = value;
-    type_codes_[i] = kFloat;
+    type_codes_[i] = kDLFloat;
   }
   void operator()(size_t i, std::nullptr_t value) const {
     values_[i].v_handle = value;
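A small usage sketch (not from the commit) of how the renamed codes surface through the packed-function ABI declared in this header; it assumes only <tvm/runtime/packed_func.h> as shown above:

#include <tvm/runtime/packed_func.h>

using namespace tvm::runtime;

// An integer argument is tagged kDLInt by TVMArgsSetter on the caller side,
// and TVMPODValue_::operator int64_t checks for kDLInt on the callee side.
PackedFunc echo([](TVMArgs args, TVMRetValue* rv) {
  CHECK(args.type_codes[0] == kDLInt);
  *rv = args[0].operator int64_t();  // return path: SwitchToPOD(kDLInt)
});

int64_t RunEcho() {
  return echo(42);  // returns 42
}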
jvm/native/src/main/native/jni_helper_func.h
@@ -161,10 +161,10 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
 jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
   switch (tcode) {
-    case kUInt:
-    case kInt:
+    case kDLUInt:
+    case kDLInt:
       return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
-    case kFloat:
+    case kDLFloat:
       return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
     case kModuleHandle:
       return newModule(env, reinterpret_cast<jlong>(value.v_handle));
jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
@@ -62,7 +62,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
   value.v_int64 = static_cast<int64_t>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kInt);
+  e->tvmFuncArgTypes.push_back(kDLInt);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
@@ -71,7 +71,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
   value.v_float64 = static_cast<double>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kFloat);
+  e->tvmFuncArgTypes.push_back(kDLFloat);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString(
src/api/api_lang.cc
@@ -27,9 +27,9 @@ TVM_REGISTER_API("_max_value")
 TVM_REGISTER_API("_const")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
-    if (args[0].type_code() == kInt) {
+    if (args[0].type_code() == kDLInt) {
       *ret = make_const(args[1], args[0].operator int64_t());
-    } else if (args[0].type_code() == kFloat) {
+    } else if (args[0].type_code() == kDLFloat) {
       *ret = make_const(args[1], args[0].operator double());
     } else {
       LOG(FATAL) << "only accept int or float";
src/codegen/llvm/codegen_amdgpu.cc
@@ -133,7 +133,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
 inline int DetectROCMComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_type = kDLROCM;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
src/codegen/llvm/codegen_llvm.cc
@@ -242,7 +242,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
     CHECK_EQ(t.lanes(), 1);
     return t_void_p_;
   }
-  llvm::Type* etype;
+  llvm::Type* etype = nullptr;
   if (t.is_int() || t.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx_, t.bits());
   } else if (t.is_float()) {
src/codegen/llvm/codegen_nvptx.cc
@@ -132,7 +132,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
 inline int DetectCUDAComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_type = kDLGPU;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
src/codegen/stack_vm/stack_vm.h
@@ -340,16 +340,16 @@ class StackVM {
   static OpCode GetLoad(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32: return ARRAY_LOAD_INT32;
         case 64: return ARRAY_LOAD_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32: return ARRAY_LOAD_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64: return ARRAY_LOAD_FP64;
       }
@@ -365,16 +365,16 @@ class StackVM {
   static OpCode GetStore(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_STORE_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32: return ARRAY_STORE_INT32;
         case 64: return ARRAY_STORE_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32: return ARRAY_STORE_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64: return ARRAY_STORE_FP64;
       }
src/codegen/verilog/vpi_device_api.cc
@@ -91,10 +91,10 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
                       TVMContext ctx_from,
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final {
-    if (static_cast<int>(ctx_from.device_type) == kVPI) {
+    if (static_cast<int>(ctx_from.device_type) == kDLVPI) {
       from = RealAddr(static_cast<const char*>(from) + from_offset, size);
     }
-    if (static_cast<int>(ctx_to.device_type) == kVPI) {
+    if (static_cast<int>(ctx_to.device_type) == kDLVPI) {
       to = RealAddr(static_cast<char*>(to) + to_offset, size);
     }
     memcpy(to, from, size);
src/contrib/cblas/cblas.cc
@@ -30,9 +30,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     cblas_sgemm(CblasColMajor,
                 transb ? CblasTrans : CblasNoTrans,
                 transa ? CblasTrans : CblasNoTrans,
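For readers unfamiliar with the helper, a hedged sketch (hypothetical name, not TVM's own definition) of what a check like TypeMatch(A->dtype, kDLFloat, 32) verifies on a DLPack tensor:

#include <dlpack/dlpack.h>

// Hypothetical equivalent of the TypeMatch helper used above: the dtype
// must be a scalar float32 in DLPack terms.
inline bool TypeMatchSketch(DLDataType t, uint8_t code, int bits) {
  return t.code == code && t.bits == bits && t.lanes == 1;
}
// e.g. TypeMatchSketch(A->dtype, kDLFloat, 32) for the matmul operands.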
src/contrib/cudnn/cudnn_utils.cc
@@ -13,17 +13,17 @@ namespace contrib {
 // CuDNN Data Type
 cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
   switch (dtype.code) {
-    case kInt:
+    case kDLInt:
       if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
       else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
       else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
       else LOG(FATAL) << "Unsupported type";
       break;
-    case kUInt:
+    case kDLUInt:
       LOG(FATAL) << "Unsupported type";
       break;
-    case kFloat:
+    case kDLFloat:
       if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
       else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
       else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
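A usage sketch for the mapping above (not part of the commit; the tvm::contrib namespace and visibility of the internal cudnn_utils.h declaration are assumptions based on this file):

#include <cudnn.h>
#include <dlpack/dlpack.h>

cudnnDataType_t Float32ToCuDNN() {
  DLDataType fp32;
  fp32.code = kDLFloat;  // DLPack 0.2 name, as renamed in this commit
  fp32.bits = 32;
  fp32.lanes = 1;
  // Returns CUDNN_DATA_FLOAT for a scalar float32 dtype.
  return tvm::contrib::CuDNNDataType::DLTypeToCuDNNType(fp32);
}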
src/contrib/nnpack/convolution.cc
@@ -44,10 +44,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_inference(nnp_convolution_algorithm_auto,
                               nnp_convolution_transform_strategy_block_based,
@@ -102,10 +102,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_output(nnp_convolution_algorithm_auto,
                            batch_size,
src/contrib/nnpack/fully_connected.cc
@@ -29,9 +29,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_inference(B->shape[1], B->shape[0],
@@ -58,9 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_output(A->shape[0], B->shape[1],
src/pass/lower_tvm_builtin.cc
@@ -72,7 +72,7 @@ class BuiltinLower : public IRMutator {
     int64_t nbytes = GetVectorBytes(op->type);
     if (device_type_.defined()) {
       if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kCPU) {
+        if (dev_type == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 &&
               constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;
src/pass/make_api.cc
@@ -107,12 +107,13 @@ LoweredFunc MakeAPI(Stmt body,
     } else if (t.is_int() || t.is_uint()) {
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be int";
-      seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
+      seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
     } else {
       CHECK(t.is_float());
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be float";
-      seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
+      seq_check.emplace_back(AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
     }
   } else {
     args.push_back(v_arg);
@@ -148,7 +149,7 @@ LoweredFunc MakeAPI(Stmt body,
   seq_check.push_back(AttrStmt::make(
       node, attr::device_context_type, device_type, nop));
   Stmt set_device = IfThenElse::make(
-      device_type != kCPU, Evaluate::make(Call::make(
+      device_type != kDLCPU, Evaluate::make(Call::make(
           Int(32), intrinsic::tvm_call_packed,
           {StringImm::make(runtime::symbol::tvm_set_device),
            device_type, device_id}, Call::Intrinsic)));
src/runtime/c_runtime_api.cc
@@ -25,12 +25,12 @@ namespace runtime {
  */
 inline std::string DeviceName(int type) {
   switch (type) {
-    case kCPU: return "cpu";
-    case kGPU: return "gpu";
-    case kOpenCL: return "opencl";
-    case kMetal: return "metal";
-    case kVPI: return "vpi";
-    case kROCM: return "rocm";
+    case kDLCPU: return "cpu";
+    case kDLGPU: return "gpu";
+    case kDLOpenCL: return "opencl";
+    case kDLMetal: return "metal";
+    case kDLVPI: return "vpi";
+    case kDLROCM: return "rocm";
     case kExtDev: return "ext_dev";
     default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
   }
@@ -126,7 +126,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
 inline void VerifyType(int dtype_code, int dtype_bits, int dtype_lanes) {
   CHECK_GE(dtype_lanes, 1);
-  if (dtype_code == kFloat) {
+  if (dtype_code == kDLFloat) {
     CHECK_EQ(dtype_bits % 32, 0);
   } else {
     CHECK_EQ(dtype_bits % 8, 0);
@@ -382,10 +382,10 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
   CHECK_EQ(from_size, to_size)
     << "TVMArrayCopyFromTo: The size must exactly match";
   TVMContext ctx = from->ctx;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = to->ctx;
   } else {
-    CHECK(to->ctx.device_type == kCPU ||
+    CHECK(to->ctx.device_type == kDLCPU ||
           to->ctx.device_type == from->ctx.device_type)
       << "Can not copy across different ctx types directly";
   }
@@ -401,7 +401,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
                           size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
@@ -418,7 +418,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
                         size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
src/runtime/cpu_device_api.cc
@@ -68,7 +68,7 @@ class CPUDeviceAPI final : public DeviceAPI {
 struct CPUWorkspacePool : public WorkspacePool {
   CPUWorkspacePool() :
-      WorkspacePool(kCPU, CPUDeviceAPI::Global()) {}
+      WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };
 void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
src/runtime/cuda/cuda_device_api.cc
@@ -79,7 +79,7 @@ class CUDADeviceAPI final : public DeviceAPI {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
+    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
@@ -88,10 +88,10 @@ class CUDADeviceAPI final : public DeviceAPI {
                             from, ctx_from.device_id, size, cu_stream);
       }
-    } else if (ctx_from.device_type == kGPU && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kGPU) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
     } else {
@@ -140,7 +140,7 @@ class CUDADeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
 CUDAThreadEntry::CUDAThreadEntry()
-    : pool(kGPU, CUDADeviceAPI::Global()) {
+    : pool(kDLGPU, CUDADeviceAPI::Global()) {
 }
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
src/runtime/graph/graph_runtime.cc
@@ -462,7 +462,7 @@ void GraphRuntime::SetupStorage() {
     int64_t shape[] = {static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
     DLTensor* tensor;
     TVM_CCALL(TVMArrayAlloc(
-        shape, 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
+        shape, 1, kDLFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
     storage_pool_.push_back(tensor);
   }
   // Assign the pooled entries.
src/runtime/metal/metal_common.h
@@ -45,14 +45,14 @@ class MetalWorkspace final : public DeviceAPI {
   ~MetalWorkspace();
   // Get command queue for given context.
   id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 &&
           static_cast<size_t>(ctx.device_id) < queues.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return queues[ctx.device_id];
   }
   // Get device for given context
   id<MTLDevice> GetDevice(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 &&
           static_cast<size_t>(ctx.device_id) < devices.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return devices[ctx.device_id];
@@ -91,9 +91,9 @@ class MetalThreadEntry {
   WorkspacePool pool;
   // constructor
   MetalThreadEntry()
-      : pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
+      : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = static_cast<DLDeviceType>(kMetal);
+    context.device_type = static_cast<DLDeviceType>(kDLMetal);
   }
   ~MetalThreadEntry();
   // Get temp buffer with at least size under ctx.
src/runtime/metal/metal_device_api.mm
@@ -150,13 +150,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
   this->Init();
   CHECK(stream == nullptr);
   TVMContext ctx = ctx_from;
-  if (ctx_from.device_type == kCPU) ctx = ctx_to;
+  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
   id<MTLCommandQueue> queue = GetCommandQueue(ctx);
   id<MTLCommandBuffer> cb = [queue commandBuffer];
   int from_dev_type = static_cast<int>(ctx_from.device_type);
   int to_dev_type = static_cast<int>(ctx_to.device_type);
-  if (from_dev_type == kMetal && to_dev_type == kMetal) {
+  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
     CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
         << "Metal disallow cross device copy.";
     id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
@@ -167,7 +167,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
                    size:size];
     [encoder endEncoding];
     [cb commit];
-  } else if (from_dev_type == kMetal && to_dev_type == kCPU) {
+  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
     // copy to a local buffer before get into global buffer.
     id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
     if (from_buf.storageMode != MTLStorageModeShared) {
@@ -190,7 +190,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
              static_cast<char*>([from_buf contents]) + from_offset,
              size);
     }
-  } else if (from_dev_type == kCPU && to_dev_type == kMetal) {
+  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
     id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
     if (to_buf.storageMode != MTLStorageModeShared) {
       id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
src/runtime/opencl/opencl_common.h
@@ -133,7 +133,7 @@ class OpenCLWorkspace final : public DeviceAPI {
   void Init();
   // get the queue of the context
   cl_command_queue GetQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kOpenCL);
+    CHECK_EQ(ctx.device_type, kDLOpenCL);
     this->Init();
     CHECK(ctx.device_id >= 0 &&
           static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid OpenCL device_id=" << ctx.device_id;
@@ -178,9 +178,9 @@ class OpenCLThreadEntry {
   WorkspacePool pool;
   // constructor
   OpenCLThreadEntry()
-      : pool(kOpenCL, OpenCLWorkspace::Global()) {
+      : pool(kDLOpenCL, OpenCLWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = kOpenCL;
+    context.device_type = kDLOpenCL;
   }
   // get the global workspace
   static OpenCLThreadEntry* ThreadLocal();
src/runtime/opencl/opencl_device_api.cc
@@ -76,13 +76,13 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
                                      TVMStreamHandle stream) {
   this->Init();
   CHECK(stream == nullptr);
-  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
+  if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLOpenCL) {
     OPENCL_CALL(clEnqueueCopyBuffer(
         this->GetQueue(ctx_to),
         static_cast<cl_mem>((void*)from),  // NOLINT(*)
         static_cast<cl_mem>(to),
         from_offset, to_offset, size, 0, nullptr, nullptr));
-  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
+  } else if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLCPU) {
     OPENCL_CALL(clEnqueueReadBuffer(
         this->GetQueue(ctx_from),
         static_cast<cl_mem>((void*)from),  // NOLINT(*)
@@ -90,7 +90,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
         static_cast<char*>(to) + to_offset,
         0, nullptr, nullptr));
     OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
-  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
+  } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLOpenCL) {
     OPENCL_CALL(clEnqueueWriteBuffer(
         this->GetQueue(ctx_to),
         static_cast<cl_mem>(to),
src/runtime/pack_args.h
@@ -104,12 +104,12 @@ enum ArgConvertCode {
 inline ArgConvertCode GetArgConvertCode(TVMType t) {
   CHECK_EQ(t.lanes, 1U)
       << "Cannot pass vector type argument to devic function for now";
-  if (t.code == kInt) {
+  if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
-  } else if (t.code == kUInt) {
+  } else if (t.code == kDLUInt) {
     if (t.bits == 32U) return INT64_TO_UINT32;
-  } else if (t.code == kFloat) {
+  } else if (t.code == kDLFloat) {
     if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
     if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
   } else if (t.code == kHandle) {
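A sketch (not from the commit) of which conversion path a float32 kernel argument now takes; pack_args.h is an internal runtime header, so assuming it is visible only inside the TVM runtime sources:

#include <tvm/runtime/packed_func.h>
#include "pack_args.h"  // internal header providing ArgConvertCode/GetArgConvertCode

using namespace tvm::runtime;

ArgConvertCode Float32Path() {
  TVMType f32 = String2TVMType("float32");  // f32.code == kDLFloat, f32.bits == 32
  return GetArgConvertCode(f32);            // FLOAT64_TO_FLOAT32
}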
src/runtime/rocm/rocm_device_api.cc
@@ -77,7 +77,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
     hipStream_t hip_stream = static_cast<hipStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kROCM && ctx_to.device_type == kROCM) {
+    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
@@ -86,10 +86,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
                             from, ctx_from.device_id, size, hip_stream);
       }
-    } else if (ctx_from.device_type == kROCM && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kROCM) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
      ROCM_CALL(hipSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
     } else {
@@ -138,7 +138,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kROCM, ROCMDeviceAPI::Global()) {
+    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
 }
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
src/runtime/rpc/rpc_device_api.cc
@@ -55,12 +55,12 @@ class RPCDeviceAPI final : public DeviceAPI {
           static_cast<const RemoteSpace*>(to)->data, to_offset,
           size, ctx_from, ctx_to, stream);
     } else if (from_dev_type > kRPCSessMask &&
-               to_dev_type == kCPU) {
+               to_dev_type == kDLCPU) {
       GetSess(ctx_from)->CopyFromRemote(
           static_cast<const RemoteSpace*>(from)->data, from_offset,
           to, to_offset, size, ctx_from);
-    } else if (from_dev_type == kCPU &&
+    } else if (from_dev_type == kDLCPU &&
                to_dev_type > kRPCSessMask) {
       GetSess(ctx_to)->CopyToRemote(
           (void*)from, from_offset,  // NOLINT(*)
src/runtime/rpc/rpc_session.cc
@@ -162,9 +162,9 @@ class RPCSession::EventHandler {
       int tcode = type_codes[i];
       TVMValue value = arg_values[i];
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType: {
           writer_->Write(&value, sizeof(TVMValue));
           break;
@@ -315,9 +315,9 @@ class RPCSession::EventHandler {
     int tcode = arg_buf_->tcode[arg_index_];
     static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
     switch (tcode) {
-      case kInt:
-      case kUInt:
-      case kFloat:
+      case kDLInt:
+      case kDLUInt:
+      case kDLFloat:
       case kTVMType:
       case kHandle:
       case kStr:
@@ -352,9 +352,9 @@ class RPCSession::EventHandler {
     TVMValue& value = arg_buf_->value[arg_index_];
     if (arg_recv_stage_ == 0) {
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType:
         case kTVMContext: {
           this->Read(&value, sizeof(TVMValue));
@@ -484,7 +484,7 @@ class RPCSession::EventHandler {
     this->Read(&offset, sizeof(offset));
     this->Read(&size, sizeof(size));
     this->Read(&ctx, sizeof(ctx));
-    if (ctx.device_type == kCPU) {
+    if (ctx.device_type == kDLCPU) {
       RPCCode code = RPCCode::kCopyAck;
       writer_->Write(&code, sizeof(code));
       writer_->Write(reinterpret_cast<char*>(handle) + offset, size);
@@ -492,7 +492,7 @@ class RPCSession::EventHandler {
       temp_data_.resize(size + 1);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(ctx)->CopyDataFromTo(
             reinterpret_cast<void*>(handle), offset,
@@ -531,7 +531,7 @@ class RPCSession::EventHandler {
     int ret_tcode = kNull;
     RPCCode code = RPCCode::kReturn;
     std::string errmsg;
-    if (copy_ctx_.device_type == kCPU) {
+    if (copy_ctx_.device_type == kDLCPU) {
       this->Read(
           reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
     } else {
@@ -539,7 +539,7 @@ class RPCSession::EventHandler {
       this->Read(&temp_data_[0], copy_size_);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
             temp_data_.data(), 0,
@@ -915,10 +915,10 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
   TVMContext ctx_to = args[6];
   TVMStreamHandle stream = args[7];
   TVMContext ctx = ctx_from;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = ctx_to;
   } else {
-    CHECK(ctx_to.device_type == kCPU ||
+    CHECK(ctx_to.device_type == kDLCPU ||
           ctx_to.device_type == ctx_from.device_type)
       << "Can not copy across different ctx types directly";
   }
tests/cpp/packed_func_test.cc
@@ -14,7 +14,7 @@ TEST(PackedFunc, Basic) {
   Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
       CHECK(args.num_args == 3);
       CHECK(args.values[0].v_float64 == 1.0);
-      CHECK(args.type_codes[0] == kFloat);
+      CHECK(args.type_codes[0] == kDLFloat);
       CHECK(args.values[1].v_handle == &a);
       CHECK(args.type_codes[1] == kArrayHandle);
      CHECK(args.values[2].v_handle == &x);