Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
944de73b
Commit
944de73b
authored
Jan 28, 2018
by
Zhixun Tan
Committed by
Tianqi Chen
Jan 27, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add type code and bits to AllocWorkspace. (#831)
parent
eb8077ff
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
98 additions
and
36 deletions
+98
-36
include/tvm/runtime/c_backend_api.h
+8
-2
include/tvm/runtime/device_api.h
+6
-2
src/codegen/codegen_opengl.cc
+2
-0
src/codegen/stack_vm/codegen_stack_vm.cc
+3
-1
src/codegen/stack_vm/stack_vm.cc
+9
-6
src/pass/lower_tvm_builtin.cc
+6
-2
src/pass/split_host_device.cc
+5
-0
src/runtime/c_runtime_api.cc
+15
-4
src/runtime/cpu_device_api.cc
+4
-2
src/runtime/cuda/cuda_device_api.cc
+1
-1
src/runtime/metal/metal_common.h
+1
-1
src/runtime/metal/metal_device_api.mm
+3
-1
src/runtime/opencl/opencl_common.h
+1
-1
src/runtime/opencl/opencl_device_api.cc
+3
-1
src/runtime/opengl/opengl_common.h
+0
-2
src/runtime/opengl/opengl_device_api.cc
+0
-9
src/runtime/rocm/rocm_device_api.cc
+1
-1
tests/webgl/test_local_multi_stage.py
+30
-0
No files found.
include/tvm/runtime/c_backend_api.h
View file @
944de73b
...
@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
...
@@ -44,14 +44,20 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
*
*
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
*
*
* \param
size
The size of the space requested.
* \param
nbytes
The size of the space requested.
* \param device_type The device type which the space will be allocated.
* \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated.
* \param device_id The device id which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
* \return nullptr when error is thrown, a valid ptr if success
*/
*/
TVM_DLL
void
*
TVMBackendAllocWorkspace
(
int
device_type
,
TVM_DLL
void
*
TVMBackendAllocWorkspace
(
int
device_type
,
int
device_id
,
int
device_id
,
uint64_t
size
);
uint64_t
nbytes
,
int
dtype_code_hint
,
int
dtype_bits_hint
);
/*!
/*!
* \brief Backend function to free temporal workspace.
* \brief Backend function to free temporal workspace.
...
...
include/tvm/runtime/device_api.h
View file @
944de73b
...
@@ -114,9 +114,13 @@ class DeviceAPI {
...
@@ -114,9 +114,13 @@ class DeviceAPI {
* - Workspace should not overlap between different threads(i.e. be threadlocal)
* - Workspace should not overlap between different threads(i.e. be threadlocal)
*
*
* \param ctx The context of allocation.
* \param ctx The context of allocation.
* \param size The size to be allocated.
* \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
*/
TVM_DLL
virtual
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
);
TVM_DLL
virtual
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
nbytes
,
TVMType
type_hint
=
{});
/*!
/*!
* \brief Free temporal workspace in backend execution.
* \brief Free temporal workspace in backend execution.
*
*
...
...
src/codegen/codegen_opengl.cc
View file @
944de73b
...
@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
...
@@ -24,6 +24,8 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
inputs_
.
clear
();
inputs_
.
clear
();
output_iter_var_
=
nullptr
;
output_iter_var_
=
nullptr
;
thread_extent_var_
=
""
;
thread_extent_var_
=
""
;
this
->
decl_stream
.
str
(
""
);
this
->
stream
.
str
(
""
);
}
}
void
CodeGenOpenGL
::
AddFunction
(
LoweredFunc
f
)
{
void
CodeGenOpenGL
::
AddFunction
(
LoweredFunc
f
)
{
...
...
src/codegen/stack_vm/codegen_stack_vm.cc
View file @
944de73b
...
@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
...
@@ -197,10 +197,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
vm_
.
stack_size
+=
size
;
vm_
.
stack_size
+=
size
;
this
->
PushOp
(
StackVM
::
TVM_STACK_ALLOCA_BY_8BYTE
,
static_cast
<
int
>
(
size
));
this
->
PushOp
(
StackVM
::
TVM_STACK_ALLOCA_BY_8BYTE
,
static_cast
<
int
>
(
size
));
}
else
if
(
op
->
name
==
"TVMBackendAllocWorkspace"
)
{
}
else
if
(
op
->
name
==
"TVMBackendAllocWorkspace"
)
{
CHECK_EQ
(
op
->
args
.
size
(),
3
U
);
CHECK_EQ
(
op
->
args
.
size
(),
5
U
);
this
->
Push
(
op
->
args
[
0
]);
this
->
Push
(
op
->
args
[
0
]);
this
->
Push
(
op
->
args
[
1
]);
this
->
Push
(
op
->
args
[
1
]);
this
->
Push
(
op
->
args
[
2
]);
this
->
Push
(
op
->
args
[
2
]);
this
->
Push
(
op
->
args
[
3
]);
this
->
Push
(
op
->
args
[
4
]);
this
->
PushOp
(
StackVM
::
TVM_DEVICE_ALLOCA
);
this
->
PushOp
(
StackVM
::
TVM_DEVICE_ALLOCA
);
}
else
if
(
op
->
name
==
"TVMBackendFreeWorkspace"
)
{
}
else
if
(
op
->
name
==
"TVMBackendFreeWorkspace"
)
{
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
...
...
src/codegen/stack_vm/stack_vm.cc
View file @
944de73b
...
@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
...
@@ -455,12 +455,15 @@ void StackVM::Run(State* s) const {
break
;
break
;
}
}
case
TVM_DEVICE_ALLOCA
:
{
case
TVM_DEVICE_ALLOCA
:
{
int
device_type
=
static_cast
<
int
>
(
stack
[
sp
-
2
].
v_int64
);
int
device_type
=
static_cast
<
int
>
(
stack
[
sp
-
4
].
v_int64
);
int
device_id
=
static_cast
<
int
>
(
stack
[
sp
-
1
].
v_int64
);
int
device_id
=
static_cast
<
int
>
(
stack
[
sp
-
3
].
v_int64
);
size_t
nbytes
=
static_cast
<
size_t
>
(
stack
[
sp
].
v_int64
);
size_t
nbytes
=
static_cast
<
size_t
>
(
stack
[
sp
-
2
].
v_int64
);
void
*
ptr
=
TVMBackendAllocWorkspace
(
device_type
,
device_id
,
nbytes
);
int
dtype_code_hint
=
static_cast
<
int
>
(
stack
[
sp
-
1
].
v_int64
);
stack
[
sp
-
2
].
v_handle
=
ptr
;
int
dtype_bits_hint
=
static_cast
<
int
>
(
stack
[
sp
].
v_int64
);
sp
=
sp
-
2
;
void
*
ptr
=
TVMBackendAllocWorkspace
(
device_type
,
device_id
,
nbytes
,
dtype_code_hint
,
dtype_bits_hint
);
stack
[
sp
-
4
].
v_handle
=
ptr
;
sp
=
sp
-
4
;
pc
=
pc
+
1
;
pc
=
pc
+
1
;
break
;
break
;
}
}
...
...
src/pass/lower_tvm_builtin.cc
View file @
944de73b
...
@@ -96,12 +96,16 @@ class BuiltinLower : public IRMutator {
...
@@ -96,12 +96,16 @@ class BuiltinLower : public IRMutator {
{
op
->
buffer_var
},
Call
::
PureIntrinsic
),
{
op
->
buffer_var
},
Call
::
PureIntrinsic
),
throw_last_error
),
throw_last_error
),
op
->
body
);
op
->
body
);
Stmt
alloca
=
LetStmt
::
make
(
op
->
buffer_var
,
Stmt
alloca
=
LetStmt
::
make
(
op
->
buffer_var
,
Call
::
make
(
op
->
buffer_var
.
type
(),
Call
::
make
(
op
->
buffer_var
.
type
(),
"TVMBackendAllocWorkspace"
,
"TVMBackendAllocWorkspace"
,
{
cast
(
Int
(
32
),
device_type_
),
{
cast
(
Int
(
32
),
device_type_
),
cast
(
Int
(
32
),
device_id_
),
cast
(
Int
(
32
),
device_id_
),
cast
(
UInt
(
64
),
total_bytes
)},
cast
(
UInt
(
64
),
total_bytes
),
IntImm
::
make
(
Int
(
32
),
op
->
type
.
code
()),
IntImm
::
make
(
Int
(
32
),
op
->
type
.
bits
())},
Call
::
Extern
),
Call
::
Extern
),
body
);
body
);
...
...
src/pass/split_host_device.cc
View file @
944de73b
...
@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
...
@@ -146,6 +146,11 @@ class IRUseDefAnalysis : public IRMutator {
class
HostDeviceSplitter
:
public
IRMutator
{
class
HostDeviceSplitter
:
public
IRMutator
{
public
:
public
:
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
handle_data_type_
[
op
->
buffer_var
.
get
()]
=
make_const
(
op
->
type
,
0
);
return
IRMutator
::
Mutate_
(
op
,
s
);
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr
::
thread_extent
||
if
(
op
->
attr_key
==
attr
::
thread_extent
||
op
->
attr_key
==
attr
::
pipeline_exec_scope
)
{
op
->
attr_key
==
attr
::
pipeline_exec_scope
)
{
...
...
src/runtime/c_runtime_api.cc
View file @
944de73b
...
@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
...
@@ -95,8 +95,9 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
static_cast
<
int
>
(
ctx
.
device_type
),
allow_missing
);
static_cast
<
int
>
(
ctx
.
device_type
),
allow_missing
);
}
}
void
*
DeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
void
*
DeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
TVMType
type_hint
{
kDLUInt
,
8
,
1
};
size_t
size
,
TVMType
type_hint
)
{
return
AllocDataSpace
(
ctx
,
size
,
kTempAllocaAlignment
,
type_hint
);
return
AllocDataSpace
(
ctx
,
size
,
kTempAllocaAlignment
,
type_hint
);
}
}
...
@@ -221,11 +222,21 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
...
@@ -221,11 +222,21 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
void
*
TVMBackendAllocWorkspace
(
int
device_type
,
void
*
TVMBackendAllocWorkspace
(
int
device_type
,
int
device_id
,
int
device_id
,
uint64_t
size
)
{
uint64_t
size
,
int
dtype_code_hint
,
int
dtype_bits_hint
)
{
TVMContext
ctx
;
TVMContext
ctx
;
ctx
.
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
ctx
.
device_type
=
static_cast
<
DLDeviceType
>
(
device_type
);
ctx
.
device_id
=
device_id
;
ctx
.
device_id
=
device_id
;
return
DeviceAPIManager
::
Get
(
ctx
)
->
AllocWorkspace
(
ctx
,
static_cast
<
size_t
>
(
size
));
TVMType
type_hint
;
type_hint
.
code
=
static_cast
<
decltype
(
type_hint
.
code
)
>
(
dtype_code_hint
);
type_hint
.
bits
=
static_cast
<
decltype
(
type_hint
.
bits
)
>
(
dtype_bits_hint
);
type_hint
.
lanes
=
1
;
return
DeviceAPIManager
::
Get
(
ctx
)
->
AllocWorkspace
(
ctx
,
static_cast
<
size_t
>
(
size
),
type_hint
);
}
}
int
TVMBackendFreeWorkspace
(
int
device_type
,
int
TVMBackendFreeWorkspace
(
int
device_type
,
...
...
src/runtime/cpu_device_api.cc
View file @
944de73b
...
@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
...
@@ -59,7 +59,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
{
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
{
}
}
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
static
const
std
::
shared_ptr
<
CPUDeviceAPI
>&
Global
()
{
static
const
std
::
shared_ptr
<
CPUDeviceAPI
>&
Global
()
{
...
@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
...
@@ -74,7 +74,9 @@ struct CPUWorkspacePool : public WorkspacePool {
WorkspacePool
(
kDLCPU
,
CPUDeviceAPI
::
Global
())
{}
WorkspacePool
(
kDLCPU
,
CPUDeviceAPI
::
Global
())
{}
};
};
void
*
CPUDeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
void
*
CPUDeviceAPI
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
{
return
dmlc
::
ThreadLocalStore
<
CPUWorkspacePool
>::
Get
()
return
dmlc
::
ThreadLocalStore
<
CPUWorkspacePool
>::
Get
()
->
AllocWorkspace
(
ctx
,
size
);
->
AllocWorkspace
(
ctx
,
size
);
}
}
...
...
src/runtime/cuda/cuda_device_api.cc
View file @
944de73b
...
@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
...
@@ -112,7 +112,7 @@ class CUDADeviceAPI final : public DeviceAPI {
->
stream
=
static_cast
<
cudaStream_t
>
(
stream
);
->
stream
=
static_cast
<
cudaStream_t
>
(
stream
);
}
}
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
{
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
final
{
return
CUDAThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
return
CUDAThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
}
}
...
...
src/runtime/metal/metal_common.h
View file @
944de73b
...
@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
...
@@ -77,7 +77,7 @@ class MetalWorkspace final : public DeviceAPI {
TVMContext
ctx_to
,
TVMContext
ctx_to
,
TVMStreamHandle
stream
)
final
;
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
// get the global workspace
// get the global workspace
static
const
std
::
shared_ptr
<
MetalWorkspace
>&
Global
();
static
const
std
::
shared_ptr
<
MetalWorkspace
>&
Global
();
...
...
src/runtime/metal/metal_device_api.mm
View file @
944de73b
...
@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
...
@@ -228,7 +228,9 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
[cb waitUntilCompleted];
[cb waitUntilCompleted];
}
}
void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size) {
void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
}
...
...
src/runtime/opencl/opencl_common.h
View file @
944de73b
...
@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
...
@@ -156,7 +156,7 @@ class OpenCLWorkspace final : public DeviceAPI {
TVMContext
ctx_to
,
TVMContext
ctx_to
,
TVMStreamHandle
stream
)
final
;
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
// get the global workspace
// get the global workspace
static
const
std
::
shared_ptr
<
OpenCLWorkspace
>&
Global
();
static
const
std
::
shared_ptr
<
OpenCLWorkspace
>&
Global
();
...
...
src/runtime/opencl/opencl_device_api.cc
View file @
944de73b
...
@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
...
@@ -108,7 +108,9 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
OPENCL_CALL
(
clFinish
(
this
->
GetQueue
(
ctx
)));
OPENCL_CALL
(
clFinish
(
this
->
GetQueue
(
ctx
)));
}
}
void
*
OpenCLWorkspace
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
void
*
OpenCLWorkspace
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
{
return
OpenCLThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
return
OpenCLThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
}
}
...
...
src/runtime/opengl/opengl_common.h
View file @
944de73b
...
@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
...
@@ -175,8 +175,6 @@ class OpenGLWorkspace final : public DeviceAPI {
TVMContext
ctx_to
,
TVMContext
ctx_to
,
TVMStreamHandle
stream
)
final
;
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
final
;
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
;
void
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
final
;
/*!
/*!
* \brief Get the global OpenGL workspace.
* \brief Get the global OpenGL workspace.
...
...
src/runtime/opengl/opengl_device_api.cc
View file @
944de73b
...
@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
...
@@ -156,15 +156,6 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
void
OpenGLWorkspace
::
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
{}
void
OpenGLWorkspace
::
StreamSync
(
TVMContext
ctx
,
TVMStreamHandle
stream
)
{}
void
*
OpenGLWorkspace
::
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
{
LOG
(
FATAL
)
<<
"Cannot allocate OpenGL workspace."
;
return
nullptr
;
}
void
OpenGLWorkspace
::
FreeWorkspace
(
TVMContext
ctx
,
void
*
data
)
{
LOG
(
FATAL
)
<<
"Cannot free OpenGL workspace."
;
}
OpenGLWorkspace
::
OpenGLWorkspace
()
{
OpenGLWorkspace
::
OpenGLWorkspace
()
{
// Set an error handler.
// Set an error handler.
// This can be called before glfwInit().
// This can be called before glfwInit().
...
...
src/runtime/rocm/rocm_device_api.cc
View file @
944de73b
...
@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
...
@@ -110,7 +110,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
->
stream
=
static_cast
<
hipStream_t
>
(
stream
);
->
stream
=
static_cast
<
hipStream_t
>
(
stream
);
}
}
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
)
final
{
void
*
AllocWorkspace
(
TVMContext
ctx
,
size_t
size
,
TVMType
type_hint
)
final
{
return
ROCMThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
return
ROCMThreadEntry
::
ThreadLocal
()
->
pool
.
AllocWorkspace
(
ctx
,
size
);
}
}
...
...
tests/webgl/test_local_multi_stage.py
0 → 100644
View file @
944de73b
import
tvm
import
numpy
as
np
def
test_local_multi_stage
():
if
not
tvm
.
module
.
enabled
(
"opengl"
):
return
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
n
=
tvm
.
var
(
"n"
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
,
dtype
=
"int32"
)
B
=
tvm
.
compute
((
n
,),
lambda
i
:
A
[
i
]
+
1
,
name
=
"B"
)
C
=
tvm
.
compute
((
n
,),
lambda
i
:
B
[
i
]
*
2
,
name
=
"C"
)
s
=
tvm
.
create_schedule
(
C
.
op
)
s
[
B
]
.
opengl
()
s
[
C
]
.
opengl
()
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"opengl"
,
name
=
"multi_stage"
)
ctx
=
tvm
.
opengl
(
0
)
n
=
10
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,))
.
astype
(
A
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,))
.
astype
(
B
.
dtype
),
ctx
)
f
(
a
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
(
a
.
asnumpy
()
+
1
)
*
2
)
if
__name__
==
"__main__"
:
test_local_multi_stage
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment