Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8214d6ca
Commit
8214d6ca
authored
Nov 03, 2017
by
Tianqi Chen
Committed by
GitHub
Nov 03, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[DLPack] Upgrade dlpack to 0.2 (#609)
parent
a152a9cb
Show whitespace changes
Inline
Side-by-side
Showing
31 changed files
with
122 additions
and
121 deletions
+122
-121
apps/howto_deploy/cpp_deploy.cc
+2
-2
dlpack
+1
-1
include/tvm/packed_func_ext.h
+2
-2
include/tvm/runtime/packed_func.h
+18
-18
jvm/native/src/main/native/jni_helper_func.h
+3
-3
jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
+2
-2
src/api/api_lang.cc
+2
-2
src/codegen/llvm/codegen_amdgpu.cc
+1
-1
src/codegen/llvm/codegen_llvm.cc
+1
-1
src/codegen/llvm/codegen_nvptx.cc
+1
-1
src/codegen/stack_vm/stack_vm.h
+6
-6
src/codegen/verilog/vpi_device_api.cc
+2
-2
src/contrib/cblas/cblas.cc
+3
-3
src/contrib/cudnn/cudnn_utils.cc
+3
-3
src/contrib/nnpack/convolution.cc
+8
-8
src/contrib/nnpack/fully_connected.cc
+6
-6
src/pass/lower_tvm_builtin.cc
+1
-1
src/pass/make_api.cc
+4
-3
src/runtime/c_runtime_api.cc
+11
-11
src/runtime/cpu_device_api.cc
+1
-1
src/runtime/cuda/cuda_device_api.cc
+4
-4
src/runtime/graph/graph_runtime.cc
+1
-1
src/runtime/metal/metal_common.h
+4
-4
src/runtime/metal/metal_device_api.mm
+4
-4
src/runtime/opencl/opencl_common.h
+3
-3
src/runtime/opencl/opencl_device_api.cc
+3
-3
src/runtime/pack_args.h
+3
-3
src/runtime/rocm/rocm_device_api.cc
+4
-4
src/runtime/rpc/rpc_device_api.cc
+2
-2
src/runtime/rpc/rpc_session.cc
+15
-15
tests/cpp/packed_func_test.cc
+1
-1
No files found.
apps/howto_deploy/cpp_deploy.cc
View file @
8214d6ca
...
@@ -28,10 +28,10 @@ void Verify(tvm::runtime::Module mod, std::string fname) {
...
@@ -28,10 +28,10 @@ void Verify(tvm::runtime::Module mod, std::string fname) {
DLTensor
*
x
;
DLTensor
*
x
;
DLTensor
*
y
;
DLTensor
*
y
;
int
ndim
=
1
;
int
ndim
=
1
;
int
dtype_code
=
kFloat
;
int
dtype_code
=
k
DL
Float
;
int
dtype_bits
=
32
;
int
dtype_bits
=
32
;
int
dtype_lanes
=
1
;
int
dtype_lanes
=
1
;
int
device_type
=
kCPU
;
int
device_type
=
k
DL
CPU
;
int
device_id
=
0
;
int
device_id
=
0
;
int64_t
shape
[
1
]
=
{
10
};
int64_t
shape
[
1
]
=
{
10
};
TVMArrayAlloc
(
shape
,
ndim
,
dtype_code
,
dtype_bits
,
dtype_lanes
,
TVMArrayAlloc
(
shape
,
ndim
,
dtype_code
,
dtype_bits
,
dtype_lanes
,
...
...
dlpack
@
10892ac9
Subproject commit
9422e98f3f4dafc6bc3473cf8484543ad376aab6
Subproject commit
10892ac964f1af7c81aae145cd3fab78bbccd297
include/tvm/packed_func_ext.h
View file @
8214d6ca
...
@@ -105,10 +105,10 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
...
@@ -105,10 +105,10 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
inline
TVMArgValue
::
operator
Halide
::
Expr
()
const
{
inline
TVMArgValue
::
operator
Halide
::
Expr
()
const
{
if
(
type_code_
==
kNull
)
return
Expr
();
if
(
type_code_
==
kNull
)
return
Expr
();
if
(
type_code_
==
kInt
)
{
if
(
type_code_
==
k
DL
Int
)
{
return
Expr
(
static_cast
<
int
>
(
value_
.
v_int64
));
return
Expr
(
static_cast
<
int
>
(
value_
.
v_int64
));
}
}
if
(
type_code_
==
kFloat
)
{
if
(
type_code_
==
k
DL
Float
)
{
return
Expr
(
static_cast
<
float
>
(
value_
.
v_float64
));
return
Expr
(
static_cast
<
float
>
(
value_
.
v_float64
));
}
}
TVM_CHECK_TYPE_CODE
(
type_code_
,
kNodeHandle
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
kNodeHandle
);
...
...
include/tvm/runtime/packed_func.h
View file @
8214d6ca
...
@@ -217,25 +217,25 @@ class ExtTypeVTable {
...
@@ -217,25 +217,25 @@ class ExtTypeVTable {
class
TVMPODValue_
{
class
TVMPODValue_
{
public
:
public
:
operator
double
()
const
{
operator
double
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kFloat
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
k
DL
Float
);
return
value_
.
v_float64
;
return
value_
.
v_float64
;
}
}
operator
int64_t
()
const
{
operator
int64_t
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kInt
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
k
DL
Int
);
return
value_
.
v_int64
;
return
value_
.
v_int64
;
}
}
operator
uint64_t
()
const
{
operator
uint64_t
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kInt
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
k
DL
Int
);
return
value_
.
v_int64
;
return
value_
.
v_int64
;
}
}
operator
int
()
const
{
operator
int
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kInt
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
k
DL
Int
);
CHECK_LE
(
value_
.
v_int64
,
CHECK_LE
(
value_
.
v_int64
,
std
::
numeric_limits
<
int
>::
max
());
std
::
numeric_limits
<
int
>::
max
());
return
static_cast
<
int
>
(
value_
.
v_int64
);
return
static_cast
<
int
>
(
value_
.
v_int64
);
}
}
operator
bool
()
const
{
operator
bool
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kInt
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
k
DL
Int
);
return
value_
.
v_int64
!=
0
;
return
value_
.
v_int64
!=
0
;
}
}
operator
void
*
()
const
{
operator
void
*
()
const
{
...
@@ -430,7 +430,7 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -430,7 +430,7 @@ class TVMRetValue : public TVMPODValue_ {
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
double
value
)
{
TVMRetValue
&
operator
=
(
double
value
)
{
this
->
SwitchToPOD
(
kFloat
);
this
->
SwitchToPOD
(
k
DL
Float
);
value_
.
v_float64
=
value
;
value_
.
v_float64
=
value
;
return
*
this
;
return
*
this
;
}
}
...
@@ -445,12 +445,12 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -445,12 +445,12 @@ class TVMRetValue : public TVMPODValue_ {
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
int64_t
value
)
{
TVMRetValue
&
operator
=
(
int64_t
value
)
{
this
->
SwitchToPOD
(
kInt
);
this
->
SwitchToPOD
(
k
DL
Int
);
value_
.
v_int64
=
value
;
value_
.
v_int64
=
value
;
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
int
value
)
{
TVMRetValue
&
operator
=
(
int
value
)
{
this
->
SwitchToPOD
(
kInt
);
this
->
SwitchToPOD
(
k
DL
Int
);
value_
.
v_int64
=
value
;
value_
.
v_int64
=
value
;
return
*
this
;
return
*
this
;
}
}
...
@@ -460,7 +460,7 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -460,7 +460,7 @@ class TVMRetValue : public TVMPODValue_ {
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
bool
value
)
{
TVMRetValue
&
operator
=
(
bool
value
)
{
this
->
SwitchToPOD
(
kInt
);
this
->
SwitchToPOD
(
k
DL
Int
);
value_
.
v_int64
=
value
;
value_
.
v_int64
=
value
;
return
*
this
;
return
*
this
;
}
}
...
@@ -609,9 +609,9 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -609,9 +609,9 @@ class TVMRetValue : public TVMPODValue_ {
// implementation details
// implementation details
inline
const
char
*
TypeCode2Str
(
int
type_code
)
{
inline
const
char
*
TypeCode2Str
(
int
type_code
)
{
switch
(
type_code
)
{
switch
(
type_code
)
{
case
kInt
:
return
"int"
;
case
k
DL
Int
:
return
"int"
;
case
kUInt
:
return
"uint"
;
case
k
DL
UInt
:
return
"uint"
;
case
kFloat
:
return
"float"
;
case
k
DL
Float
:
return
"float"
;
case
kStr
:
return
"str"
;
case
kStr
:
return
"str"
;
case
kBytes
:
return
"bytes"
;
case
kBytes
:
return
"bytes"
;
case
kHandle
:
return
"handle"
;
case
kHandle
:
return
"handle"
;
...
@@ -648,11 +648,11 @@ inline TVMType String2TVMType(std::string s) {
...
@@ -648,11 +648,11 @@ inline TVMType String2TVMType(std::string s) {
t
.
bits
=
32
;
t
.
lanes
=
1
;
t
.
bits
=
32
;
t
.
lanes
=
1
;
const
char
*
scan
;
const
char
*
scan
;
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
t
.
code
=
kInt
;
scan
=
s
.
c_str
()
+
3
;
t
.
code
=
k
DL
Int
;
scan
=
s
.
c_str
()
+
3
;
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
t
.
code
=
kUInt
;
scan
=
s
.
c_str
()
+
4
;
t
.
code
=
k
DL
UInt
;
scan
=
s
.
c_str
()
+
4
;
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
t
.
code
=
kFloat
;
scan
=
s
.
c_str
()
+
5
;
t
.
code
=
k
DL
Float
;
scan
=
s
.
c_str
()
+
5
;
}
else
if
(
s
.
substr
(
0
,
6
)
==
"handle"
)
{
}
else
if
(
s
.
substr
(
0
,
6
)
==
"handle"
)
{
t
.
code
=
kHandle
;
t
.
code
=
kHandle
;
t
.
bits
=
64
;
// handle uses 64 bit by default.
t
.
bits
=
64
;
// handle uses 64 bit by default.
...
@@ -724,17 +724,17 @@ class TVMArgsSetter {
...
@@ -724,17 +724,17 @@ class TVMArgsSetter {
std
::
is_integral
<
T
>::
value
>::
type
>
std
::
is_integral
<
T
>::
value
>::
type
>
void
operator
()(
size_t
i
,
T
value
)
const
{
void
operator
()(
size_t
i
,
T
value
)
const
{
values_
[
i
].
v_int64
=
static_cast
<
int64_t
>
(
value
);
values_
[
i
].
v_int64
=
static_cast
<
int64_t
>
(
value
);
type_codes_
[
i
]
=
kInt
;
type_codes_
[
i
]
=
k
DL
Int
;
}
}
void
operator
()(
size_t
i
,
uint64_t
value
)
const
{
void
operator
()(
size_t
i
,
uint64_t
value
)
const
{
values_
[
i
].
v_int64
=
static_cast
<
int64_t
>
(
value
);
values_
[
i
].
v_int64
=
static_cast
<
int64_t
>
(
value
);
CHECK_LE
(
value
,
CHECK_LE
(
value
,
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
int64_t
>::
max
()));
static_cast
<
uint64_t
>
(
std
::
numeric_limits
<
int64_t
>::
max
()));
type_codes_
[
i
]
=
kInt
;
type_codes_
[
i
]
=
k
DL
Int
;
}
}
void
operator
()(
size_t
i
,
double
value
)
const
{
void
operator
()(
size_t
i
,
double
value
)
const
{
values_
[
i
].
v_float64
=
value
;
values_
[
i
].
v_float64
=
value
;
type_codes_
[
i
]
=
kFloat
;
type_codes_
[
i
]
=
k
DL
Float
;
}
}
void
operator
()(
size_t
i
,
std
::
nullptr_t
value
)
const
{
void
operator
()(
size_t
i
,
std
::
nullptr_t
value
)
const
{
values_
[
i
].
v_handle
=
value
;
values_
[
i
].
v_handle
=
value
;
...
...
jvm/native/src/main/native/jni_helper_func.h
View file @
8214d6ca
...
@@ -161,10 +161,10 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
...
@@ -161,10 +161,10 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
jobject
tvmRetValueToJava
(
JNIEnv
*
env
,
TVMValue
value
,
int
tcode
)
{
jobject
tvmRetValueToJava
(
JNIEnv
*
env
,
TVMValue
value
,
int
tcode
)
{
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kUInt
:
case
k
DL
UInt
:
case
kInt
:
case
k
DL
Int
:
return
newTVMValueLong
(
env
,
static_cast
<
jlong
>
(
value
.
v_int64
));
return
newTVMValueLong
(
env
,
static_cast
<
jlong
>
(
value
.
v_int64
));
case
kFloat
:
case
k
DL
Float
:
return
newTVMValueDouble
(
env
,
static_cast
<
jdouble
>
(
value
.
v_float64
));
return
newTVMValueDouble
(
env
,
static_cast
<
jdouble
>
(
value
.
v_float64
));
case
kModuleHandle
:
case
kModuleHandle
:
return
newModule
(
env
,
reinterpret_cast
<
jlong
>
(
value
.
v_handle
));
return
newModule
(
env
,
reinterpret_cast
<
jlong
>
(
value
.
v_handle
));
...
...
jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
View file @
8214d6ca
...
@@ -62,7 +62,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
...
@@ -62,7 +62,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
value
.
v_int64
=
static_cast
<
int64_t
>
(
arg
);
value
.
v_int64
=
static_cast
<
int64_t
>
(
arg
);
TVMFuncArgsThreadLocalEntry
*
e
=
TVMFuncArgsThreadLocalStore
::
Get
();
TVMFuncArgsThreadLocalEntry
*
e
=
TVMFuncArgsThreadLocalStore
::
Get
();
e
->
tvmFuncArgValues
.
push_back
(
value
);
e
->
tvmFuncArgValues
.
push_back
(
value
);
e
->
tvmFuncArgTypes
.
push_back
(
kInt
);
e
->
tvmFuncArgTypes
.
push_back
(
k
DL
Int
);
}
}
JNIEXPORT
void
JNICALL
Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble
(
JNIEXPORT
void
JNICALL
Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble
(
...
@@ -71,7 +71,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
...
@@ -71,7 +71,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
value
.
v_float64
=
static_cast
<
double
>
(
arg
);
value
.
v_float64
=
static_cast
<
double
>
(
arg
);
TVMFuncArgsThreadLocalEntry
*
e
=
TVMFuncArgsThreadLocalStore
::
Get
();
TVMFuncArgsThreadLocalEntry
*
e
=
TVMFuncArgsThreadLocalStore
::
Get
();
e
->
tvmFuncArgValues
.
push_back
(
value
);
e
->
tvmFuncArgValues
.
push_back
(
value
);
e
->
tvmFuncArgTypes
.
push_back
(
kFloat
);
e
->
tvmFuncArgTypes
.
push_back
(
k
DL
Float
);
}
}
JNIEXPORT
void
JNICALL
Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString
(
JNIEXPORT
void
JNICALL
Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString
(
...
...
src/api/api_lang.cc
View file @
8214d6ca
...
@@ -27,9 +27,9 @@ TVM_REGISTER_API("_max_value")
...
@@ -27,9 +27,9 @@ TVM_REGISTER_API("_max_value")
TVM_REGISTER_API
(
"_const"
)
TVM_REGISTER_API
(
"_const"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
if
(
args
[
0
].
type_code
()
==
kInt
)
{
if
(
args
[
0
].
type_code
()
==
k
DL
Int
)
{
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
int64_t
());
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
int64_t
());
}
else
if
(
args
[
0
].
type_code
()
==
kFloat
)
{
}
else
if
(
args
[
0
].
type_code
()
==
k
DL
Float
)
{
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
double
());
*
ret
=
make_const
(
args
[
1
],
args
[
0
].
operator
double
());
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"only accept int or float"
;
LOG
(
FATAL
)
<<
"only accept int or float"
;
...
...
src/codegen/llvm/codegen_amdgpu.cc
View file @
8214d6ca
...
@@ -133,7 +133,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
...
@@ -133,7 +133,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
inline
int
DetectROCMComputeVersion
()
{
inline
int
DetectROCMComputeVersion
()
{
TVMContext
tvm_ctx
;
TVMContext
tvm_ctx
;
tvm_ctx
.
device_type
=
kROCM
;
tvm_ctx
.
device_type
=
k
DL
ROCM
;
tvm_ctx
.
device_id
=
0
;
tvm_ctx
.
device_id
=
0
;
TVMRetValue
val
;
TVMRetValue
val
;
tvm
::
runtime
::
DeviceAPI
::
Get
(
tvm_ctx
)
->
GetAttr
(
tvm
::
runtime
::
DeviceAPI
::
Get
(
tvm_ctx
)
->
GetAttr
(
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
8214d6ca
...
@@ -242,7 +242,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
...
@@ -242,7 +242,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
CHECK_EQ
(
t
.
lanes
(),
1
);
CHECK_EQ
(
t
.
lanes
(),
1
);
return
t_void_p_
;
return
t_void_p_
;
}
}
llvm
::
Type
*
etype
;
llvm
::
Type
*
etype
=
nullptr
;
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
etype
=
llvm
::
Type
::
getIntNTy
(
*
ctx_
,
t
.
bits
());
etype
=
llvm
::
Type
::
getIntNTy
(
*
ctx_
,
t
.
bits
());
}
else
if
(
t
.
is_float
())
{
}
else
if
(
t
.
is_float
())
{
...
...
src/codegen/llvm/codegen_nvptx.cc
View file @
8214d6ca
...
@@ -132,7 +132,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
...
@@ -132,7 +132,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
inline
int
DetectCUDAComputeVersion
()
{
inline
int
DetectCUDAComputeVersion
()
{
TVMContext
tvm_ctx
;
TVMContext
tvm_ctx
;
tvm_ctx
.
device_type
=
kGPU
;
tvm_ctx
.
device_type
=
k
DL
GPU
;
tvm_ctx
.
device_id
=
0
;
tvm_ctx
.
device_id
=
0
;
TVMRetValue
val
;
TVMRetValue
val
;
tvm
::
runtime
::
DeviceAPI
::
Get
(
tvm_ctx
)
->
GetAttr
(
tvm
::
runtime
::
DeviceAPI
::
Get
(
tvm_ctx
)
->
GetAttr
(
...
...
src/codegen/stack_vm/stack_vm.h
View file @
8214d6ca
...
@@ -340,16 +340,16 @@ class StackVM {
...
@@ -340,16 +340,16 @@ class StackVM {
static
OpCode
GetLoad
(
TVMType
t
)
{
static
OpCode
GetLoad
(
TVMType
t
)
{
CHECK_EQ
(
t
.
lanes
,
1U
);
CHECK_EQ
(
t
.
lanes
,
1U
);
if
(
t
.
code
==
kHandle
)
return
ARRAY_LOAD_HANDLE
;
if
(
t
.
code
==
kHandle
)
return
ARRAY_LOAD_HANDLE
;
if
(
t
.
code
==
kInt
)
{
if
(
t
.
code
==
k
DL
Int
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
32
:
return
ARRAY_LOAD_INT32
;
case
32
:
return
ARRAY_LOAD_INT32
;
case
64
:
return
ARRAY_LOAD_INT64
;
case
64
:
return
ARRAY_LOAD_INT64
;
}
}
}
else
if
(
t
.
code
==
kUInt
)
{
}
else
if
(
t
.
code
==
k
DL
UInt
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
32
:
return
ARRAY_LOAD_UINT32
;
case
32
:
return
ARRAY_LOAD_UINT32
;
}
}
}
else
if
(
t
.
code
==
kFloat
)
{
}
else
if
(
t
.
code
==
k
DL
Float
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
64
:
return
ARRAY_LOAD_FP64
;
case
64
:
return
ARRAY_LOAD_FP64
;
}
}
...
@@ -365,16 +365,16 @@ class StackVM {
...
@@ -365,16 +365,16 @@ class StackVM {
static
OpCode
GetStore
(
TVMType
t
)
{
static
OpCode
GetStore
(
TVMType
t
)
{
CHECK_EQ
(
t
.
lanes
,
1U
);
CHECK_EQ
(
t
.
lanes
,
1U
);
if
(
t
.
code
==
kHandle
)
return
ARRAY_STORE_HANDLE
;
if
(
t
.
code
==
kHandle
)
return
ARRAY_STORE_HANDLE
;
if
(
t
.
code
==
kInt
)
{
if
(
t
.
code
==
k
DL
Int
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
32
:
return
ARRAY_STORE_INT32
;
case
32
:
return
ARRAY_STORE_INT32
;
case
64
:
return
ARRAY_STORE_INT64
;
case
64
:
return
ARRAY_STORE_INT64
;
}
}
}
else
if
(
t
.
code
==
kUInt
)
{
}
else
if
(
t
.
code
==
k
DL
UInt
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
32
:
return
ARRAY_STORE_UINT32
;
case
32
:
return
ARRAY_STORE_UINT32
;
}
}
}
else
if
(
t
.
code
==
kFloat
)
{
}
else
if
(
t
.
code
==
k
DL
Float
)
{
switch
(
t
.
bits
)
{
switch
(
t
.
bits
)
{
case
64
:
return
ARRAY_STORE_FP64
;
case
64
:
return
ARRAY_STORE_FP64
;
}
}
...
...
src/codegen/verilog/vpi_device_api.cc
View file @
8214d6ca
...
@@ -91,10 +91,10 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
...
@@ -91,10 +91,10 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
TVMContext
ctx_from
,
TVMContext
ctx_from
,
TVMContext
ctx_to
,
TVMContext
ctx_to
,
TVMStreamHandle
stream
)
final
{
TVMStreamHandle
stream
)
final
{
if
(
static_cast
<
int
>
(
ctx_from
.
device_type
)
==
kVPI
)
{
if
(
static_cast
<
int
>
(
ctx_from
.
device_type
)
==
k
DL
VPI
)
{
from
=
RealAddr
(
static_cast
<
const
char
*>
(
from
)
+
from_offset
,
size
);
from
=
RealAddr
(
static_cast
<
const
char
*>
(
from
)
+
from_offset
,
size
);
}
}
if
(
static_cast
<
int
>
(
ctx_to
.
device_type
)
==
kVPI
)
{
if
(
static_cast
<
int
>
(
ctx_to
.
device_type
)
==
k
DL
VPI
)
{
to
=
RealAddr
(
static_cast
<
char
*>
(
to
)
+
to_offset
,
size
);
to
=
RealAddr
(
static_cast
<
char
*>
(
to
)
+
to_offset
,
size
);
}
}
memcpy
(
to
,
from
,
size
);
memcpy
(
to
,
from
,
size
);
...
...
src/contrib/cblas/cblas.cc
View file @
8214d6ca
...
@@ -30,9 +30,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
...
@@ -30,9 +30,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
TypeMatch
(
A
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
A
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
k
DL
Float
,
32
));
cblas_sgemm
(
CblasColMajor
,
cblas_sgemm
(
CblasColMajor
,
transb
?
CblasTrans
:
CblasNoTrans
,
transb
?
CblasTrans
:
CblasNoTrans
,
transa
?
CblasTrans
:
CblasNoTrans
,
transa
?
CblasTrans
:
CblasNoTrans
,
...
...
src/contrib/cudnn/cudnn_utils.cc
View file @
8214d6ca
...
@@ -13,17 +13,17 @@ namespace contrib {
...
@@ -13,17 +13,17 @@ namespace contrib {
// CuDNN Data Type
// CuDNN Data Type
cudnnDataType_t
CuDNNDataType
::
DLTypeToCuDNNType
(
const
DLDataType
&
dtype
)
{
cudnnDataType_t
CuDNNDataType
::
DLTypeToCuDNNType
(
const
DLDataType
&
dtype
)
{
switch
(
dtype
.
code
)
{
switch
(
dtype
.
code
)
{
case
kInt
:
case
k
DL
Int
:
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_INT8
;
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_INT8
;
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_INT32
;
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_INT32
;
else
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
4
)
return
CUDNN_DATA_INT8x4
;
else
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
4
)
return
CUDNN_DATA_INT8x4
;
else
else
LOG
(
FATAL
)
<<
"Unsupported type"
;
LOG
(
FATAL
)
<<
"Unsupported type"
;
break
;
break
;
case
kUInt
:
case
k
DL
UInt
:
LOG
(
FATAL
)
<<
"Unsupported type"
;
LOG
(
FATAL
)
<<
"Unsupported type"
;
break
;
break
;
case
kFloat
:
case
k
DL
Float
:
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_FLOAT
;
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_FLOAT
;
else
if
(
dtype
.
bits
==
64
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_DOUBLE
;
else
if
(
dtype
.
bits
==
64
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_DOUBLE
;
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_HALF
;
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
CUDNN_DATA_HALF
;
...
...
src/contrib/nnpack/convolution.cc
View file @
8214d6ca
...
@@ -44,10 +44,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
...
@@ -44,10 +44,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
CHECK
(
kernel
->
strides
==
nullptr
);
CHECK
(
kernel
->
strides
==
nullptr
);
CHECK
(
bias
->
strides
==
nullptr
);
CHECK
(
bias
->
strides
==
nullptr
);
CHECK
(
TypeMatch
(
input
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
input
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
kernel
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
kernel
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
bias
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
bias
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
output
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
output
->
dtype
,
k
DL
Float
,
32
));
nnp_convolution_inference
(
nnp_convolution_algorithm_auto
,
nnp_convolution_inference
(
nnp_convolution_algorithm_auto
,
nnp_convolution_transform_strategy_block_based
,
nnp_convolution_transform_strategy_block_based
,
...
@@ -102,10 +102,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
...
@@ -102,10 +102,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
CHECK
(
kernel
->
strides
==
nullptr
);
CHECK
(
kernel
->
strides
==
nullptr
);
CHECK
(
bias
->
strides
==
nullptr
);
CHECK
(
bias
->
strides
==
nullptr
);
CHECK
(
TypeMatch
(
input
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
input
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
kernel
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
kernel
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
bias
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
bias
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
output
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
output
->
dtype
,
k
DL
Float
,
32
));
nnp_convolution_output
(
nnp_convolution_algorithm_auto
,
nnp_convolution_output
(
nnp_convolution_algorithm_auto
,
batch_size
,
batch_size
,
...
...
src/contrib/nnpack/fully_connected.cc
View file @
8214d6ca
...
@@ -29,9 +29,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
...
@@ -29,9 +29,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
TypeMatch
(
A
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
A
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
k
DL
Float
,
32
));
nnp_fully_connected_inference
(
B
->
shape
[
1
],
nnp_fully_connected_inference
(
B
->
shape
[
1
],
B
->
shape
[
0
],
B
->
shape
[
0
],
...
@@ -58,9 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
...
@@ -58,9 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
C
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
B
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
A
->
strides
==
nullptr
);
CHECK
(
TypeMatch
(
A
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
A
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
B
->
dtype
,
k
DL
Float
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
kFloat
,
32
));
CHECK
(
TypeMatch
(
C
->
dtype
,
k
DL
Float
,
32
));
nnp_fully_connected_output
(
A
->
shape
[
0
],
nnp_fully_connected_output
(
A
->
shape
[
0
],
B
->
shape
[
1
],
B
->
shape
[
1
],
...
...
src/pass/lower_tvm_builtin.cc
View file @
8214d6ca
...
@@ -72,7 +72,7 @@ class BuiltinLower : public IRMutator {
...
@@ -72,7 +72,7 @@ class BuiltinLower : public IRMutator {
int64_t
nbytes
=
GetVectorBytes
(
op
->
type
);
int64_t
nbytes
=
GetVectorBytes
(
op
->
type
);
if
(
device_type_
.
defined
())
{
if
(
device_type_
.
defined
())
{
if
(
arith
::
GetConst
(
device_type_
,
&
dev_type
))
{
if
(
arith
::
GetConst
(
device_type_
,
&
dev_type
))
{
if
(
dev_type
==
kCPU
)
{
if
(
dev_type
==
k
DL
CPU
)
{
int32_t
constant_size
=
op
->
constant_allocation_size
();
int32_t
constant_size
=
op
->
constant_allocation_size
();
if
(
constant_size
>
0
&&
constant_size
*
nbytes
<
runtime
::
kMaxStackAlloca
)
{
if
(
constant_size
>
0
&&
constant_size
*
nbytes
<
runtime
::
kMaxStackAlloca
)
{
return
stmt
;
return
stmt
;
...
...
src/pass/make_api.cc
View file @
8214d6ca
...
@@ -107,12 +107,13 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -107,12 +107,13 @@ LoweredFunc MakeAPI(Stmt body,
}
else
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
}
else
if
(
t
.
is_int
()
||
t
.
is_uint
())
{
std
::
ostringstream
msg
;
std
::
ostringstream
msg
;
msg
<<
name
<<
": Expect arg["
<<
i
<<
"] to be int"
;
msg
<<
name
<<
": Expect arg["
<<
i
<<
"] to be int"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kInt
,
msg
.
str
(),
nop
));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
k
DL
Int
,
msg
.
str
(),
nop
));
}
else
{
}
else
{
CHECK
(
t
.
is_float
());
CHECK
(
t
.
is_float
());
std
::
ostringstream
msg
;
std
::
ostringstream
msg
;
msg
<<
name
<<
": Expect arg["
<<
i
<<
"] to be float"
;
msg
<<
name
<<
": Expect arg["
<<
i
<<
"] to be float"
;
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kFloat
,
msg
.
str
(),
nop
));
seq_check
.
emplace_back
(
AssertStmt
::
make
(
tcode
==
kDLFloat
,
msg
.
str
(),
nop
));
}
}
}
else
{
}
else
{
args
.
push_back
(
v_arg
);
args
.
push_back
(
v_arg
);
...
@@ -148,7 +149,7 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -148,7 +149,7 @@ LoweredFunc MakeAPI(Stmt body,
seq_check
.
push_back
(
AttrStmt
::
make
(
seq_check
.
push_back
(
AttrStmt
::
make
(
node
,
attr
::
device_context_type
,
device_type
,
nop
));
node
,
attr
::
device_context_type
,
device_type
,
nop
));
Stmt
set_device
=
IfThenElse
::
make
(
Stmt
set_device
=
IfThenElse
::
make
(
device_type
!=
kCPU
,
Evaluate
::
make
(
Call
::
make
(
device_type
!=
k
DL
CPU
,
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
intrinsic
::
tvm_call_packed
,
Int
(
32
),
intrinsic
::
tvm_call_packed
,
{
StringImm
::
make
(
runtime
::
symbol
::
tvm_set_device
),
{
StringImm
::
make
(
runtime
::
symbol
::
tvm_set_device
),
device_type
,
device_id
},
Call
::
Intrinsic
)));
device_type
,
device_id
},
Call
::
Intrinsic
)));
...
...
src/runtime/c_runtime_api.cc
View file @
8214d6ca
...
@@ -25,12 +25,12 @@ namespace runtime {
...
@@ -25,12 +25,12 @@ namespace runtime {
*/
*/
inline
std
::
string
DeviceName
(
int
type
)
{
inline
std
::
string
DeviceName
(
int
type
)
{
switch
(
type
)
{
switch
(
type
)
{
case
kCPU
:
return
"cpu"
;
case
k
DL
CPU
:
return
"cpu"
;
case
kGPU
:
return
"gpu"
;
case
k
DL
GPU
:
return
"gpu"
;
case
kOpenCL
:
return
"opencl"
;
case
k
DL
OpenCL
:
return
"opencl"
;
case
kMetal
:
return
"metal"
;
case
k
DL
Metal
:
return
"metal"
;
case
kVPI
:
return
"vpi"
;
case
k
DL
VPI
:
return
"vpi"
;
case
kROCM
:
return
"rocm"
;
case
k
DL
ROCM
:
return
"rocm"
;
case
kExtDev
:
return
"ext_dev"
;
case
kExtDev
:
return
"ext_dev"
;
default
:
LOG
(
FATAL
)
<<
"unknown type ="
<<
type
;
return
"Unknown"
;
default
:
LOG
(
FATAL
)
<<
"unknown type ="
<<
type
;
return
"Unknown"
;
}
}
...
@@ -126,7 +126,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
...
@@ -126,7 +126,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
inline
void
VerifyType
(
int
dtype_code
,
int
dtype_bits
,
int
dtype_lanes
)
{
inline
void
VerifyType
(
int
dtype_code
,
int
dtype_bits
,
int
dtype_lanes
)
{
CHECK_GE
(
dtype_lanes
,
1
);
CHECK_GE
(
dtype_lanes
,
1
);
if
(
dtype_code
==
kFloat
)
{
if
(
dtype_code
==
k
DL
Float
)
{
CHECK_EQ
(
dtype_bits
%
32
,
0
);
CHECK_EQ
(
dtype_bits
%
32
,
0
);
}
else
{
}
else
{
CHECK_EQ
(
dtype_bits
%
8
,
0
);
CHECK_EQ
(
dtype_bits
%
8
,
0
);
...
@@ -382,10 +382,10 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
...
@@ -382,10 +382,10 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
CHECK_EQ
(
from_size
,
to_size
)
CHECK_EQ
(
from_size
,
to_size
)
<<
"TVMArrayCopyFromTo: The size must exactly match"
;
<<
"TVMArrayCopyFromTo: The size must exactly match"
;
TVMContext
ctx
=
from
->
ctx
;
TVMContext
ctx
=
from
->
ctx
;
if
(
ctx
.
device_type
==
kCPU
)
{
if
(
ctx
.
device_type
==
k
DL
CPU
)
{
ctx
=
to
->
ctx
;
ctx
=
to
->
ctx
;
}
else
{
}
else
{
CHECK
(
to
->
ctx
.
device_type
==
kCPU
||
CHECK
(
to
->
ctx
.
device_type
==
k
DL
CPU
||
to
->
ctx
.
device_type
==
from
->
ctx
.
device_type
)
to
->
ctx
.
device_type
==
from
->
ctx
.
device_type
)
<<
"Can not copy across different ctx types directly"
;
<<
"Can not copy across different ctx types directly"
;
}
}
...
@@ -401,7 +401,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
...
@@ -401,7 +401,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
size_t
nbytes
)
{
size_t
nbytes
)
{
API_BEGIN
();
API_BEGIN
();
TVMContext
cpu_ctx
;
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kCPU
;
cpu_ctx
.
device_type
=
k
DL
CPU
;
cpu_ctx
.
device_id
=
0
;
cpu_ctx
.
device_id
=
0
;
size_t
arr_size
=
GetDataSize
(
handle
);
size_t
arr_size
=
GetDataSize
(
handle
);
CHECK_EQ
(
arr_size
,
nbytes
)
CHECK_EQ
(
arr_size
,
nbytes
)
...
@@ -418,7 +418,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
...
@@ -418,7 +418,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
size_t
nbytes
)
{
size_t
nbytes
)
{
API_BEGIN
();
API_BEGIN
();
TVMContext
cpu_ctx
;
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kCPU
;
cpu_ctx
.
device_type
=
k
DL
CPU
;
cpu_ctx
.
device_id
=
0
;
cpu_ctx
.
device_id
=
0
;
size_t
arr_size
=
GetDataSize
(
handle
);
size_t
arr_size
=
GetDataSize
(
handle
);
CHECK_EQ
(
arr_size
,
nbytes
)
CHECK_EQ
(
arr_size
,
nbytes
)
...
...
src/runtime/cpu_device_api.cc
View file @
8214d6ca
...
@@ -68,7 +68,7 @@ class CPUDeviceAPI final : public DeviceAPI {
...
@@ -68,7 +68,7 @@ class CPUDeviceAPI final : public DeviceAPI {
struct
CPUWorkspacePool
:
public
WorkspacePool
{
struct
CPUWorkspacePool
:
public
WorkspacePool
{
CPUWorkspacePool
()
:
CPUWorkspacePool
()
:
WorkspacePool
(
kCPU
,
CPUDeviceAPI
::
Global
())
{}
WorkspacePool
(
k
DL
CPU
,
CPUDeviceAPI
::
Global
())
{}
};
};
void
*
CPUDeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
void
*
CPUDeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
...
...
src/runtime/cuda/cuda_device_api.cc
View file @
8214d6ca
...
@@ -79,7 +79,7 @@ class CUDADeviceAPI final : public DeviceAPI {
...
@@ -79,7 +79,7 @@ class CUDADeviceAPI final : public DeviceAPI {
cudaStream_t
cu_stream
=
static_cast
<
cudaStream_t
>
(
stream
);
cudaStream_t
cu_stream
=
static_cast
<
cudaStream_t
>
(
stream
);
from
=
static_cast
<
const
char
*>
(
from
)
+
from_offset
;
from
=
static_cast
<
const
char
*>
(
from
)
+
from_offset
;
to
=
static_cast
<
char
*>
(
to
)
+
to_offset
;
to
=
static_cast
<
char
*>
(
to
)
+
to_offset
;
if
(
ctx_from
.
device_type
==
k
GPU
&&
ctx_to
.
device_type
==
k
GPU
)
{
if
(
ctx_from
.
device_type
==
k
DLGPU
&&
ctx_to
.
device_type
==
kDL
GPU
)
{
CUDA_CALL
(
cudaSetDevice
(
ctx_from
.
device_id
));
CUDA_CALL
(
cudaSetDevice
(
ctx_from
.
device_id
));
if
(
ctx_from
.
device_id
==
ctx_to
.
device_id
)
{
if
(
ctx_from
.
device_id
==
ctx_to
.
device_id
)
{
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyDeviceToDevice
,
cu_stream
);
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyDeviceToDevice
,
cu_stream
);
...
@@ -88,10 +88,10 @@ class CUDADeviceAPI final : public DeviceAPI {
...
@@ -88,10 +88,10 @@ class CUDADeviceAPI final : public DeviceAPI {
from
,
ctx_from
.
device_id
,
from
,
ctx_from
.
device_id
,
size
,
cu_stream
);
size
,
cu_stream
);
}
}
}
else
if
(
ctx_from
.
device_type
==
k
GPU
&&
ctx_to
.
device_type
==
k
CPU
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLGPU
&&
ctx_to
.
device_type
==
kDL
CPU
)
{
CUDA_CALL
(
cudaSetDevice
(
ctx_from
.
device_id
));
CUDA_CALL
(
cudaSetDevice
(
ctx_from
.
device_id
));
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyDeviceToHost
,
cu_stream
);
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyDeviceToHost
,
cu_stream
);
}
else
if
(
ctx_from
.
device_type
==
k
CPU
&&
ctx_to
.
device_type
==
k
GPU
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLCPU
&&
ctx_to
.
device_type
==
kDL
GPU
)
{
CUDA_CALL
(
cudaSetDevice
(
ctx_to
.
device_id
));
CUDA_CALL
(
cudaSetDevice
(
ctx_to
.
device_id
));
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyHostToDevice
,
cu_stream
);
GPUCopy
(
from
,
to
,
size
,
cudaMemcpyHostToDevice
,
cu_stream
);
}
else
{
}
else
{
...
@@ -140,7 +140,7 @@ class CUDADeviceAPI final : public DeviceAPI {
...
@@ -140,7 +140,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef
dmlc
::
ThreadLocalStore
<
CUDAThreadEntry
>
CUDAThreadStore
;
typedef
dmlc
::
ThreadLocalStore
<
CUDAThreadEntry
>
CUDAThreadStore
;
CUDAThreadEntry
::
CUDAThreadEntry
()
CUDAThreadEntry
::
CUDAThreadEntry
()
:
pool
(
kGPU
,
CUDADeviceAPI
::
Global
())
{
:
pool
(
k
DL
GPU
,
CUDADeviceAPI
::
Global
())
{
}
}
CUDAThreadEntry
*
CUDAThreadEntry
::
ThreadLocal
()
{
CUDAThreadEntry
*
CUDAThreadEntry
::
ThreadLocal
()
{
...
...
src/runtime/graph/graph_runtime.cc
View file @
8214d6ca
...
@@ -462,7 +462,7 @@ void GraphRuntime::SetupStorage() {
...
@@ -462,7 +462,7 @@ void GraphRuntime::SetupStorage() {
int64_t
shape
[]
=
{
static_cast
<
int64_t
>
(
pool_entry_bytes
[
i
]
+
3
)
/
4
};
int64_t
shape
[]
=
{
static_cast
<
int64_t
>
(
pool_entry_bytes
[
i
]
+
3
)
/
4
};
DLTensor
*
tensor
;
DLTensor
*
tensor
;
TVM_CCALL
(
TVMArrayAlloc
(
TVM_CCALL
(
TVMArrayAlloc
(
shape
,
1
,
kFloat
,
32
,
1
,
ctx_
.
device_type
,
ctx_
.
device_id
,
&
tensor
));
shape
,
1
,
k
DL
Float
,
32
,
1
,
ctx_
.
device_type
,
ctx_
.
device_id
,
&
tensor
));
storage_pool_
.
push_back
(
tensor
);
storage_pool_
.
push_back
(
tensor
);
}
}
// Assign the pooled entries.
// Assign the pooled entries.
...
...
src/runtime/metal/metal_common.h
View file @
8214d6ca
...
@@ -45,14 +45,14 @@ class MetalWorkspace final : public DeviceAPI {
...
@@ -45,14 +45,14 @@ class MetalWorkspace final : public DeviceAPI {
~
MetalWorkspace
();
~
MetalWorkspace
();
// Get command queue for given context.
// Get command queue for given context.
id
<
MTLCommandQueue
>
GetCommandQueue
(
TVMContext
ctx
)
{
id
<
MTLCommandQueue
>
GetCommandQueue
(
TVMContext
ctx
)
{
CHECK_EQ
(
ctx
.
device_type
,
kMetal
);
CHECK_EQ
(
ctx
.
device_type
,
k
DL
Metal
);
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
queues
.
size
())
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
queues
.
size
())
<<
"Invalid Metal device_id="
<<
ctx
.
device_id
;
<<
"Invalid Metal device_id="
<<
ctx
.
device_id
;
return
queues
[
ctx
.
device_id
];
return
queues
[
ctx
.
device_id
];
}
}
// Get device for given context
// Get device for given context
id
<
MTLDevice
>
GetDevice
(
TVMContext
ctx
)
{
id
<
MTLDevice
>
GetDevice
(
TVMContext
ctx
)
{
CHECK_EQ
(
ctx
.
device_type
,
kMetal
);
CHECK_EQ
(
ctx
.
device_type
,
k
DL
Metal
);
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
devices
.
size
())
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
devices
.
size
())
<<
"Invalid Metal device_id="
<<
ctx
.
device_id
;
<<
"Invalid Metal device_id="
<<
ctx
.
device_id
;
return
devices
[
ctx
.
device_id
];
return
devices
[
ctx
.
device_id
];
...
@@ -91,9 +91,9 @@ class MetalThreadEntry {
...
@@ -91,9 +91,9 @@ class MetalThreadEntry {
WorkspacePool
pool
;
WorkspacePool
pool
;
// constructor
// constructor
MetalThreadEntry
()
MetalThreadEntry
()
:
pool
(
static_cast
<
DLDeviceType
>
(
kMetal
),
MetalWorkspace
::
Global
())
{
:
pool
(
static_cast
<
DLDeviceType
>
(
k
DL
Metal
),
MetalWorkspace
::
Global
())
{
context
.
device_id
=
0
;
context
.
device_id
=
0
;
context
.
device_type
=
static_cast
<
DLDeviceType
>
(
kMetal
);
context
.
device_type
=
static_cast
<
DLDeviceType
>
(
k
DL
Metal
);
}
}
~
MetalThreadEntry
();
~
MetalThreadEntry
();
// Get temp buffer with at least size under ctx.
// Get temp buffer with at least size under ctx.
...
...
src/runtime/metal/metal_device_api.mm
View file @
8214d6ca
...
@@ -150,13 +150,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
...
@@ -150,13 +150,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
this->Init();
this->Init();
CHECK(stream == nullptr);
CHECK(stream == nullptr);
TVMContext ctx = ctx_from;
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kCPU) ctx = ctx_to;
if (ctx_from.device_type == k
DL
CPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = [queue commandBuffer];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == k
Metal && to_dev_type == k
Metal) {
if (from_dev_type == k
DLMetal && to_dev_type == kDL
Metal) {
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
<< "Metal disallow cross device copy.";
<< "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
...
@@ -167,7 +167,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
...
@@ -167,7 +167,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size:size];
size:size];
[encoder endEncoding];
[encoder endEncoding];
[cb commit];
[cb commit];
} else if (from_dev_type == k
Metal && to_dev_type == k
CPU) {
} else if (from_dev_type == k
DLMetal && to_dev_type == kDL
CPU) {
// copy to a local buffer before get into global buffer.
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
if (from_buf.storageMode != MTLStorageModeShared) {
...
@@ -190,7 +190,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
...
@@ -190,7 +190,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
static_cast<char*>([from_buf contents]) + from_offset,
static_cast<char*>([from_buf contents]) + from_offset,
size);
size);
}
}
} else if (from_dev_type == k
CPU && to_dev_type == k
Metal) {
} else if (from_dev_type == k
DLCPU && to_dev_type == kDL
Metal) {
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
...
...
src/runtime/opencl/opencl_common.h
View file @
8214d6ca
...
@@ -133,7 +133,7 @@ class OpenCLWorkspace final : public DeviceAPI {
...
@@ -133,7 +133,7 @@ class OpenCLWorkspace final : public DeviceAPI {
void
Init
();
void
Init
();
// get the queue of the context
// get the queue of the context
cl_command_queue
GetQueue
(
TVMContext
ctx
)
{
cl_command_queue
GetQueue
(
TVMContext
ctx
)
{
CHECK_EQ
(
ctx
.
device_type
,
kOpenCL
);
CHECK_EQ
(
ctx
.
device_type
,
k
DL
OpenCL
);
this
->
Init
();
this
->
Init
();
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
queues
.
size
())
CHECK
(
ctx
.
device_id
>=
0
&&
static_cast
<
size_t
>
(
ctx
.
device_id
)
<
queues
.
size
())
<<
"Invalid OpenCL device_id="
<<
ctx
.
device_id
;
<<
"Invalid OpenCL device_id="
<<
ctx
.
device_id
;
...
@@ -178,9 +178,9 @@ class OpenCLThreadEntry {
...
@@ -178,9 +178,9 @@ class OpenCLThreadEntry {
WorkspacePool
pool
;
WorkspacePool
pool
;
// constructor
// constructor
OpenCLThreadEntry
()
OpenCLThreadEntry
()
:
pool
(
kOpenCL
,
OpenCLWorkspace
::
Global
())
{
:
pool
(
k
DL
OpenCL
,
OpenCLWorkspace
::
Global
())
{
context
.
device_id
=
0
;
context
.
device_id
=
0
;
context
.
device_type
=
kOpenCL
;
context
.
device_type
=
k
DL
OpenCL
;
}
}
// get the global workspace
// get the global workspace
static
OpenCLThreadEntry
*
ThreadLocal
();
static
OpenCLThreadEntry
*
ThreadLocal
();
...
...
src/runtime/opencl/opencl_device_api.cc
View file @
8214d6ca
...
@@ -76,13 +76,13 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
...
@@ -76,13 +76,13 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
TVMStreamHandle
stream
)
{
TVMStreamHandle
stream
)
{
this
->
Init
();
this
->
Init
();
CHECK
(
stream
==
nullptr
);
CHECK
(
stream
==
nullptr
);
if
(
ctx_from
.
device_type
==
k
OpenCL
&&
ctx_to
.
device_type
==
k
OpenCL
)
{
if
(
ctx_from
.
device_type
==
k
DLOpenCL
&&
ctx_to
.
device_type
==
kDL
OpenCL
)
{
OPENCL_CALL
(
clEnqueueCopyBuffer
(
OPENCL_CALL
(
clEnqueueCopyBuffer
(
this
->
GetQueue
(
ctx_to
),
this
->
GetQueue
(
ctx_to
),
static_cast
<
cl_mem
>
((
void
*
)
from
),
// NOLINT(*)
static_cast
<
cl_mem
>
((
void
*
)
from
),
// NOLINT(*)
static_cast
<
cl_mem
>
(
to
),
static_cast
<
cl_mem
>
(
to
),
from_offset
,
to_offset
,
size
,
0
,
nullptr
,
nullptr
));
from_offset
,
to_offset
,
size
,
0
,
nullptr
,
nullptr
));
}
else
if
(
ctx_from
.
device_type
==
k
OpenCL
&&
ctx_to
.
device_type
==
k
CPU
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLOpenCL
&&
ctx_to
.
device_type
==
kDL
CPU
)
{
OPENCL_CALL
(
clEnqueueReadBuffer
(
OPENCL_CALL
(
clEnqueueReadBuffer
(
this
->
GetQueue
(
ctx_from
),
this
->
GetQueue
(
ctx_from
),
static_cast
<
cl_mem
>
((
void
*
)
from
),
// NOLINT(*)
static_cast
<
cl_mem
>
((
void
*
)
from
),
// NOLINT(*)
...
@@ -90,7 +90,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
...
@@ -90,7 +90,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
static_cast
<
char
*>
(
to
)
+
to_offset
,
static_cast
<
char
*>
(
to
)
+
to_offset
,
0
,
nullptr
,
nullptr
));
0
,
nullptr
,
nullptr
));
OPENCL_CALL
(
clFinish
(
this
->
GetQueue
(
ctx_from
)));
OPENCL_CALL
(
clFinish
(
this
->
GetQueue
(
ctx_from
)));
}
else
if
(
ctx_from
.
device_type
==
k
CPU
&&
ctx_to
.
device_type
==
k
OpenCL
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLCPU
&&
ctx_to
.
device_type
==
kDL
OpenCL
)
{
OPENCL_CALL
(
clEnqueueWriteBuffer
(
OPENCL_CALL
(
clEnqueueWriteBuffer
(
this
->
GetQueue
(
ctx_to
),
this
->
GetQueue
(
ctx_to
),
static_cast
<
cl_mem
>
(
to
),
static_cast
<
cl_mem
>
(
to
),
...
...
src/runtime/pack_args.h
View file @
8214d6ca
...
@@ -104,12 +104,12 @@ enum ArgConvertCode {
...
@@ -104,12 +104,12 @@ enum ArgConvertCode {
inline
ArgConvertCode
GetArgConvertCode
(
TVMType
t
)
{
inline
ArgConvertCode
GetArgConvertCode
(
TVMType
t
)
{
CHECK_EQ
(
t
.
lanes
,
1U
)
CHECK_EQ
(
t
.
lanes
,
1U
)
<<
"Cannot pass vector type argument to devic function for now"
;
<<
"Cannot pass vector type argument to devic function for now"
;
if
(
t
.
code
==
kInt
)
{
if
(
t
.
code
==
k
DL
Int
)
{
if
(
t
.
bits
==
64U
)
return
INT64_TO_INT64
;
if
(
t
.
bits
==
64U
)
return
INT64_TO_INT64
;
if
(
t
.
bits
==
32U
)
return
INT64_TO_INT32
;
if
(
t
.
bits
==
32U
)
return
INT64_TO_INT32
;
}
else
if
(
t
.
code
==
kUInt
)
{
}
else
if
(
t
.
code
==
k
DL
UInt
)
{
if
(
t
.
bits
==
32U
)
return
INT64_TO_UINT32
;
if
(
t
.
bits
==
32U
)
return
INT64_TO_UINT32
;
}
else
if
(
t
.
code
==
kFloat
)
{
}
else
if
(
t
.
code
==
k
DL
Float
)
{
if
(
t
.
bits
==
64U
)
return
FLOAT64_TO_FLOAT64
;
if
(
t
.
bits
==
64U
)
return
FLOAT64_TO_FLOAT64
;
if
(
t
.
bits
==
32U
)
return
FLOAT64_TO_FLOAT32
;
if
(
t
.
bits
==
32U
)
return
FLOAT64_TO_FLOAT32
;
}
else
if
(
t
.
code
==
kHandle
)
{
}
else
if
(
t
.
code
==
kHandle
)
{
...
...
src/runtime/rocm/rocm_device_api.cc
View file @
8214d6ca
...
@@ -77,7 +77,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
...
@@ -77,7 +77,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
hipStream_t
hip_stream
=
static_cast
<
hipStream_t
>
(
stream
);
hipStream_t
hip_stream
=
static_cast
<
hipStream_t
>
(
stream
);
from
=
static_cast
<
const
char
*>
(
from
)
+
from_offset
;
from
=
static_cast
<
const
char
*>
(
from
)
+
from_offset
;
to
=
static_cast
<
char
*>
(
to
)
+
to_offset
;
to
=
static_cast
<
char
*>
(
to
)
+
to_offset
;
if
(
ctx_from
.
device_type
==
k
ROCM
&&
ctx_to
.
device_type
==
k
ROCM
)
{
if
(
ctx_from
.
device_type
==
k
DLROCM
&&
ctx_to
.
device_type
==
kDL
ROCM
)
{
ROCM_CALL
(
hipSetDevice
(
ctx_from
.
device_id
));
ROCM_CALL
(
hipSetDevice
(
ctx_from
.
device_id
));
if
(
ctx_from
.
device_id
==
ctx_to
.
device_id
)
{
if
(
ctx_from
.
device_id
==
ctx_to
.
device_id
)
{
GPUCopy
(
from
,
to
,
size
,
hipMemcpyDeviceToDevice
,
hip_stream
);
GPUCopy
(
from
,
to
,
size
,
hipMemcpyDeviceToDevice
,
hip_stream
);
...
@@ -86,10 +86,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
...
@@ -86,10 +86,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
from
,
ctx_from
.
device_id
,
from
,
ctx_from
.
device_id
,
size
,
hip_stream
);
size
,
hip_stream
);
}
}
}
else
if
(
ctx_from
.
device_type
==
k
ROCM
&&
ctx_to
.
device_type
==
k
CPU
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLROCM
&&
ctx_to
.
device_type
==
kDL
CPU
)
{
ROCM_CALL
(
hipSetDevice
(
ctx_from
.
device_id
));
ROCM_CALL
(
hipSetDevice
(
ctx_from
.
device_id
));
GPUCopy
(
from
,
to
,
size
,
hipMemcpyDeviceToHost
,
hip_stream
);
GPUCopy
(
from
,
to
,
size
,
hipMemcpyDeviceToHost
,
hip_stream
);
}
else
if
(
ctx_from
.
device_type
==
k
CPU
&&
ctx_to
.
device_type
==
k
ROCM
)
{
}
else
if
(
ctx_from
.
device_type
==
k
DLCPU
&&
ctx_to
.
device_type
==
kDL
ROCM
)
{
ROCM_CALL
(
hipSetDevice
(
ctx_to
.
device_id
));
ROCM_CALL
(
hipSetDevice
(
ctx_to
.
device_id
));
GPUCopy
(
from
,
to
,
size
,
hipMemcpyHostToDevice
,
hip_stream
);
GPUCopy
(
from
,
to
,
size
,
hipMemcpyHostToDevice
,
hip_stream
);
}
else
{
}
else
{
...
@@ -138,7 +138,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
...
@@ -138,7 +138,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
typedef
dmlc
::
ThreadLocalStore
<
ROCMThreadEntry
>
ROCMThreadStore
;
typedef
dmlc
::
ThreadLocalStore
<
ROCMThreadEntry
>
ROCMThreadStore
;
ROCMThreadEntry
::
ROCMThreadEntry
()
ROCMThreadEntry
::
ROCMThreadEntry
()
:
pool
(
kROCM
,
ROCMDeviceAPI
::
Global
())
{
:
pool
(
k
DL
ROCM
,
ROCMDeviceAPI
::
Global
())
{
}
}
ROCMThreadEntry
*
ROCMThreadEntry
::
ThreadLocal
()
{
ROCMThreadEntry
*
ROCMThreadEntry
::
ThreadLocal
()
{
...
...
src/runtime/rpc/rpc_device_api.cc
View file @
8214d6ca
...
@@ -55,12 +55,12 @@ class RPCDeviceAPI final : public DeviceAPI {
...
@@ -55,12 +55,12 @@ class RPCDeviceAPI final : public DeviceAPI {
static_cast
<
const
RemoteSpace
*>
(
to
)
->
data
,
to_offset
,
static_cast
<
const
RemoteSpace
*>
(
to
)
->
data
,
to_offset
,
size
,
ctx_from
,
ctx_to
,
stream
);
size
,
ctx_from
,
ctx_to
,
stream
);
}
else
if
(
from_dev_type
>
kRPCSessMask
&&
}
else
if
(
from_dev_type
>
kRPCSessMask
&&
to_dev_type
==
kCPU
)
{
to_dev_type
==
k
DL
CPU
)
{
GetSess
(
ctx_from
)
->
CopyFromRemote
(
GetSess
(
ctx_from
)
->
CopyFromRemote
(
static_cast
<
const
RemoteSpace
*>
(
from
)
->
data
,
from_offset
,
static_cast
<
const
RemoteSpace
*>
(
from
)
->
data
,
from_offset
,
to
,
to_offset
,
size
,
to
,
to_offset
,
size
,
ctx_from
);
ctx_from
);
}
else
if
(
from_dev_type
==
kCPU
&&
}
else
if
(
from_dev_type
==
k
DL
CPU
&&
to_dev_type
>
kRPCSessMask
)
{
to_dev_type
>
kRPCSessMask
)
{
GetSess
(
ctx_to
)
->
CopyToRemote
(
GetSess
(
ctx_to
)
->
CopyToRemote
(
(
void
*
)
from
,
from_offset
,
// NOLINT(*)
(
void
*
)
from
,
from_offset
,
// NOLINT(*)
...
...
src/runtime/rpc/rpc_session.cc
View file @
8214d6ca
...
@@ -162,9 +162,9 @@ class RPCSession::EventHandler {
...
@@ -162,9 +162,9 @@ class RPCSession::EventHandler {
int
tcode
=
type_codes
[
i
];
int
tcode
=
type_codes
[
i
];
TVMValue
value
=
arg_values
[
i
];
TVMValue
value
=
arg_values
[
i
];
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kInt
:
case
k
DL
Int
:
case
kUInt
:
case
k
DL
UInt
:
case
kFloat
:
case
k
DL
Float
:
case
kTVMType
:
{
case
kTVMType
:
{
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
));
writer_
->
Write
(
&
value
,
sizeof
(
TVMValue
));
break
;
break
;
...
@@ -315,9 +315,9 @@ class RPCSession::EventHandler {
...
@@ -315,9 +315,9 @@ class RPCSession::EventHandler {
int
tcode
=
arg_buf_
->
tcode
[
arg_index_
];
int
tcode
=
arg_buf_
->
tcode
[
arg_index_
];
static_assert
(
sizeof
(
TVMValue
)
==
sizeof
(
uint64_t
),
"invariant"
);
static_assert
(
sizeof
(
TVMValue
)
==
sizeof
(
uint64_t
),
"invariant"
);
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kInt
:
case
k
DL
Int
:
case
kUInt
:
case
k
DL
UInt
:
case
kFloat
:
case
k
DL
Float
:
case
kTVMType
:
case
kTVMType
:
case
kHandle
:
case
kHandle
:
case
kStr
:
case
kStr
:
...
@@ -352,9 +352,9 @@ class RPCSession::EventHandler {
...
@@ -352,9 +352,9 @@ class RPCSession::EventHandler {
TVMValue
&
value
=
arg_buf_
->
value
[
arg_index_
];
TVMValue
&
value
=
arg_buf_
->
value
[
arg_index_
];
if
(
arg_recv_stage_
==
0
)
{
if
(
arg_recv_stage_
==
0
)
{
switch
(
tcode
)
{
switch
(
tcode
)
{
case
kInt
:
case
k
DL
Int
:
case
kUInt
:
case
k
DL
UInt
:
case
kFloat
:
case
k
DL
Float
:
case
kTVMType
:
case
kTVMType
:
case
kTVMContext
:
{
case
kTVMContext
:
{
this
->
Read
(
&
value
,
sizeof
(
TVMValue
));
this
->
Read
(
&
value
,
sizeof
(
TVMValue
));
...
@@ -484,7 +484,7 @@ class RPCSession::EventHandler {
...
@@ -484,7 +484,7 @@ class RPCSession::EventHandler {
this
->
Read
(
&
offset
,
sizeof
(
offset
));
this
->
Read
(
&
offset
,
sizeof
(
offset
));
this
->
Read
(
&
size
,
sizeof
(
size
));
this
->
Read
(
&
size
,
sizeof
(
size
));
this
->
Read
(
&
ctx
,
sizeof
(
ctx
));
this
->
Read
(
&
ctx
,
sizeof
(
ctx
));
if
(
ctx
.
device_type
==
kCPU
)
{
if
(
ctx
.
device_type
==
k
DL
CPU
)
{
RPCCode
code
=
RPCCode
::
kCopyAck
;
RPCCode
code
=
RPCCode
::
kCopyAck
;
writer_
->
Write
(
&
code
,
sizeof
(
code
));
writer_
->
Write
(
&
code
,
sizeof
(
code
));
writer_
->
Write
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
writer_
->
Write
(
reinterpret_cast
<
char
*>
(
handle
)
+
offset
,
size
);
...
@@ -492,7 +492,7 @@ class RPCSession::EventHandler {
...
@@ -492,7 +492,7 @@ class RPCSession::EventHandler {
temp_data_
.
resize
(
size
+
1
);
temp_data_
.
resize
(
size
+
1
);
try
{
try
{
TVMContext
cpu_ctx
;
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kCPU
;
cpu_ctx
.
device_type
=
k
DL
CPU
;
cpu_ctx
.
device_id
=
0
;
cpu_ctx
.
device_id
=
0
;
DeviceAPI
::
Get
(
ctx
)
->
CopyDataFromTo
(
DeviceAPI
::
Get
(
ctx
)
->
CopyDataFromTo
(
reinterpret_cast
<
void
*>
(
handle
),
offset
,
reinterpret_cast
<
void
*>
(
handle
),
offset
,
...
@@ -531,7 +531,7 @@ class RPCSession::EventHandler {
...
@@ -531,7 +531,7 @@ class RPCSession::EventHandler {
int
ret_tcode
=
kNull
;
int
ret_tcode
=
kNull
;
RPCCode
code
=
RPCCode
::
kReturn
;
RPCCode
code
=
RPCCode
::
kReturn
;
std
::
string
errmsg
;
std
::
string
errmsg
;
if
(
copy_ctx_
.
device_type
==
kCPU
)
{
if
(
copy_ctx_
.
device_type
==
k
DL
CPU
)
{
this
->
Read
(
this
->
Read
(
reinterpret_cast
<
char
*>
(
copy_handle_
)
+
copy_offset_
,
copy_size_
);
reinterpret_cast
<
char
*>
(
copy_handle_
)
+
copy_offset_
,
copy_size_
);
}
else
{
}
else
{
...
@@ -539,7 +539,7 @@ class RPCSession::EventHandler {
...
@@ -539,7 +539,7 @@ class RPCSession::EventHandler {
this
->
Read
(
&
temp_data_
[
0
],
copy_size_
);
this
->
Read
(
&
temp_data_
[
0
],
copy_size_
);
try
{
try
{
TVMContext
cpu_ctx
;
TVMContext
cpu_ctx
;
cpu_ctx
.
device_type
=
kCPU
;
cpu_ctx
.
device_type
=
k
DL
CPU
;
cpu_ctx
.
device_id
=
0
;
cpu_ctx
.
device_id
=
0
;
DeviceAPI
::
Get
(
copy_ctx_
)
->
CopyDataFromTo
(
DeviceAPI
::
Get
(
copy_ctx_
)
->
CopyDataFromTo
(
temp_data_
.
data
(),
0
,
temp_data_
.
data
(),
0
,
...
@@ -915,10 +915,10 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
...
@@ -915,10 +915,10 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
TVMContext
ctx_to
=
args
[
6
];
TVMContext
ctx_to
=
args
[
6
];
TVMStreamHandle
stream
=
args
[
7
];
TVMStreamHandle
stream
=
args
[
7
];
TVMContext
ctx
=
ctx_from
;
TVMContext
ctx
=
ctx_from
;
if
(
ctx
.
device_type
==
kCPU
)
{
if
(
ctx
.
device_type
==
k
DL
CPU
)
{
ctx
=
ctx_to
;
ctx
=
ctx_to
;
}
else
{
}
else
{
CHECK
(
ctx_to
.
device_type
==
kCPU
||
CHECK
(
ctx_to
.
device_type
==
k
DL
CPU
||
ctx_to
.
device_type
==
ctx_from
.
device_type
)
ctx_to
.
device_type
==
ctx_from
.
device_type
)
<<
"Can not copy across different ctx types directly"
;
<<
"Can not copy across different ctx types directly"
;
}
}
...
...
tests/cpp/packed_func_test.cc
View file @
8214d6ca
...
@@ -14,7 +14,7 @@ TEST(PackedFunc, Basic) {
...
@@ -14,7 +14,7 @@ TEST(PackedFunc, Basic) {
Var
v
=
PackedFunc
([
&
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Var
v
=
PackedFunc
([
&
](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
CHECK
(
args
.
num_args
==
3
);
CHECK
(
args
.
num_args
==
3
);
CHECK
(
args
.
values
[
0
].
v_float64
==
1.0
);
CHECK
(
args
.
values
[
0
].
v_float64
==
1.0
);
CHECK
(
args
.
type_codes
[
0
]
==
kFloat
);
CHECK
(
args
.
type_codes
[
0
]
==
k
DL
Float
);
CHECK
(
args
.
values
[
1
].
v_handle
==
&
a
);
CHECK
(
args
.
values
[
1
].
v_handle
==
&
a
);
CHECK
(
args
.
type_codes
[
1
]
==
kArrayHandle
);
CHECK
(
args
.
type_codes
[
1
]
==
kArrayHandle
);
CHECK
(
args
.
values
[
2
].
v_handle
==
&
x
);
CHECK
(
args
.
values
[
2
].
v_handle
==
&
x
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment