wenyuanbo / tic · Commits

Commit 163c4795
[CODEGEN] Bugfix multiple condition generation (#558)

Authored Oct 15, 2017 by Tianqi Chen, committed via GitHub on Oct 15, 2017
Parent: 10faa893
Showing 9 changed files with 75 additions and 55 deletions (+75 / -55):
  src/codegen/llvm/codegen_amdgpu.cc        +18  -16
  src/codegen/llvm/codegen_llvm.cc           +4   -2
  src/codegen/llvm/codegen_nvptx.cc         +25   -3
  topi/python/topi/generic/injective.py      +1   -0
  topi/python/topi/generic/nn.py             +1   -0
  topi/recipe/gemm/cuda_gemm_square.py       +2   -2
  topi/tests/python/test_topi_reduce.py      +2   -4
  topi/tests/python/test_topi_softmax.py     +4   -3
  topi/tests/python/test_topi_transform.py  +18  -25
src/codegen/llvm/codegen_amdgpu.cc

@@ -131,26 +131,29 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };
 
-runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
-  CHECK(target.length() >= 4 &&
-        target.substr(0, 4) == "rocm");
-  TVMContext tvmCtx;
-  tvmCtx.device_type = kROCM;
-  tvmCtx.device_id = 0;
-  TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kExist, &val);
-  if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kComputeVersion, &val);
-  } else {
-    val = 803;
-  }
-  llvm::TargetMachine* tm = \
-      GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" + \
-      std::to_string(val.operator int()) + target.substr(4, target.length() - 4));
+inline int DetectROCMComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_id = 0;
+  TVMRetValue val;
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
+  if (val.operator int() == 1) {
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    return val.operator int();
+  } else {
+    return 803;
+  }
+}
+
+runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
+  CHECK(target.length() >= 4 &&
+        target.substr(0, 4) == "rocm");
+  std::ostringstream config;
+  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
+         << DetectROCMComputeVersion()
+         << target.substr(4, target.length() - 4);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);

@@ -159,7 +162,6 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   }
   std::unique_ptr<llvm::Module> module = cg->Finish();
   llvm::SmallString<8> dataObj, data_ll, dataAsm;
   llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
   destObj.SetUnbuffered();
src/codegen/llvm/codegen_llvm.cc

@@ -582,14 +582,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
     builder_->SetInsertPoint(then_block);
     llvm::Value* then_value = MakeValue(op->args[1]);
+    BasicBlock* then_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(else_block);
     llvm::Value* else_value = MakeValue(op->args[2]);
+    BasicBlock* else_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(end_block);
     llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
-    value->addIncoming(then_value, then_block);
-    value->addIncoming(else_value, else_block);
+    value->addIncoming(then_value, then_value_block);
+    value->addIncoming(else_value, else_value_block);
     return value;
   } else {
     LOG(FATAL) << "unknown intrinsic " << op->name;
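The two addIncoming changes above are the actual "multiple condition" fix: when the then/else expressions of a tvm_if_then_else call themselves contain further tvm_if_then_else calls, lowering the inner call moves the builder into new basic blocks, so the block that finally branches to end_block is no longer then_block/else_block. Recording builder_->GetInsertBlock() right after each value is generated gives the PHI node its true predecessor blocks. Below is a minimal, hypothetical Python reproducer sketch written against the 2017-era TVM API; the shapes, names, and the ite helper are illustrative assumptions, not part of this commit.

import tvm

# Thin wrapper over the tvm_if_then_else intrinsic handled by
# CodeGenLLVM::CreateIntrinsic above (helper name is hypothetical).
def ite(cond, then_val, else_val):
    return tvm.call_pure_intrin("float32", "tvm_if_then_else",
                                cond, then_val, else_val)

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
# Nested conditions: the inner ite() emits its own then/else/end blocks,
# which is exactly the case whose PHI incoming blocks were wrong before.
B = tvm.compute(
    (n,),
    lambda i: ite(A[i] > 0.0,
                  ite(A[i] > 1.0, A[i] * 2.0, A[i] + 1.0),
                  A[i] - 1.0),
    name="B")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], target="llvm")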
src/codegen/llvm/codegen_nvptx.cc

@@ -130,12 +130,34 @@ class CodeGenNVPTX : public CodeGenLLVM {
   }
 };
 
+inline int DetectCUDAComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_id = 0;
+  TVMRetValue val;
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
+  if (val.operator int() == 1) {
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    std::string version = val;
+    std::istringstream is(version);
+    double ver;
+    is >> ver;
+    return static_cast<int>(ver * 10);
+  } else {
+    return 20;
+  }
+}
+
 runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
   CHECK(target.length() >= 5 &&
         target.substr(0, 5) == "nvptx");
-  llvm::TargetMachine* tm = GetLLVMTargetMachine(
-      "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
-      target.substr(5, target.length() - 5));
+  std::ostringstream config;
+  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
+         << DetectCUDAComputeVersion()
+         << target.substr(5, target.length() - 5);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
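With DetectCUDAComputeVersion() in place, a plain "nvptx" target now queries the attached GPU for its compute capability and only falls back to sm_20 when no device is present, instead of always emitting sm_20 PTX; "rocm" gets the analogous treatment via DetectROCMComputeVersion(). A hedged usage sketch against the 2017-era API (the vector-add workload and split factor are illustrative only):

import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
# "nvptx" alone now resolves to -mcpu=sm_<detected compute version>.
f = tvm.build(s, [A, B], target="nvptx")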
topi/python/topi/generic/injective.py

@@ -22,6 +22,7 @@ def schedule_injective(outs):
     target = tvm.target.current_target(allow_none=False)
     if target.target_name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     x = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
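The added line normalizes outs so the generic llvm fallback accepts either a single output Tensor or a list of Tensors; the same normalization is added to _default_schedule in nn.py below. A small hedged usage sketch, assuming the 2017-era tvm.target.create context manager and an arbitrary injective compute:

import tvm
import topi

A = tvm.placeholder((16, 16), name="A")
B = tvm.compute((16, 16), lambda i, j: A[i, j] * 2.0, name="B")

with tvm.target.create("llvm"):
    s1 = topi.generic.schedule_injective(B)    # bare Tensor is now handled
    s2 = topi.generic.schedule_injective([B])  # list form keeps working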
topi/python/topi/generic/nn.py

@@ -6,6 +6,7 @@ import tvm
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
     target = tvm.target.current_target(allow_none=False)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name != "llvm":
         raise RuntimeError("schedule_pool not registered for '%s'" % target)
     s = tvm.create_schedule([x.op for x in outs])
topi/recipe/gemm/cuda_gemm_square.py

@@ -125,10 +125,10 @@ def test_gemm():
         GFLOPS = num_flops / (t * 1e3) / 1e6
         print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
 
-    for device in ['cuda', 'opencl', 'rocm']:
+    for device in ["cuda", "opencl", "rocm"]:
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=device == 'rocm'):
+                              unroll_explicit=(device != "cuda")):
             check_device(device)
 
 if __name__ == "__main__":
topi/tests/python/test_topi_reduce.py

@@ -74,11 +74,9 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         for _ in range(1):
             foo(data_tvm, out_tvm)
         np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
+        check_device(device)
 
 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
topi/tests/python/test_topi_softmax.py

@@ -3,6 +3,7 @@ import os
 import numpy as np
 import tvm
 import topi
+import logging
 from topi.util import get_const_tuple
 
 def verify_softmax(m, n):

@@ -42,8 +43,6 @@ def verify_log_softmax(m, n):
     # confirm lower works
     s = tvm.create_schedule([B.op])
     tvm.lower(s, [A, B], simple_mode=True)
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)

@@ -60,13 +59,15 @@ def verify_log_softmax(m, n):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ["cuda", "opencl", "metal", "rocm"]:
         check_device(device)
 
 def test_log_softmax():
     verify_log_softmax(32, 10)
     verify_log_softmax(3, 4)
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     test_softmax()
     test_log_softmax()
topi/tests/python/test_topi_transform.py

@@ -21,10 +21,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_tranpose(in_shape, axes):

@@ -45,10 +43,9 @@ def verify_tranpose(in_shape, axes):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")

@@ -68,10 +65,9 @@ def verify_reshape(src_shape, dst_shape):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")

@@ -95,10 +91,8 @@ def verify_squeeze(src_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_concatenate(shapes, axis):
     tensor_l = []

@@ -120,10 +114,9 @@ def verify_concatenate(shapes, axis):
         foo(*(data_nds + [out_nd]))
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_split(src_shape, indices_or_sections, axis):
     A = tvm.placeholder(shape=src_shape, name="A")

@@ -144,10 +137,9 @@ def verify_split(src_shape, indices_or_sections, axis):
         for out_nd, out_npy in zip(out_nds, out_npys):
             np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)

@@ -175,6 +167,7 @@ def test_squeeze():
 def test_concatenate():
+    verify_concatenate([(2,), (2,), (2,)], 0)
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
     verify_concatenate([(5, 6, 7, 3),

@@ -190,9 +183,9 @@ def test_split():
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
 if __name__ == "__main__":
+    test_concatenate()
     test_tranpose()
     test_expand_dims()
     test_reshape()
     test_squeeze()
-    test_concatenate()
     test_split()
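The test refactors above replace hard-coded check_device(...) calls with a loop over a device list that now also covers llvm and nvptx. They rely on each test's check_device helper skipping targets whose runtime is not enabled; a hedged sketch of that pattern (the helper body here is an assumption for illustration, not part of this diff):

import tvm

def check_device(device):
    # Skip quietly when the runtime for this target was not built.
    if not tvm.module.enabled(device):
        print("Skip because %s is not enabled" % device)
        return
    ctx = tvm.context(device, 0)
    # ... build the operator for `device`, run it on `ctx`,
    # and compare against the numpy reference ...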