Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
af9f69a7
Commit
af9f69a7
authored
Jan 12, 2018
by
Yuwei Hu
Committed by
Tianqi Chen
Jan 11, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[INTRIN] enable popcount on cuda, opencl, metal (#774)
parent
e4a51303
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
36 deletions
+69
-36
src/codegen/intrin_rule.h
+4
-8
src/codegen/intrin_rule_cuda.cc
+16
-0
src/codegen/intrin_rule_metal.cc
+8
-5
src/codegen/intrin_rule_opencl.cc
+8
-5
tests/python/integration/test_ewise.py
+33
-18
No files found.
src/codegen/intrin_rule.h
View file @
af9f69a7
...
...
@@ -30,18 +30,14 @@ struct FloatSuffix {
}
};
//
Add float suffix to the intrinsics
struct
Float
Direct
{
//
Return the intrinsic name
struct
Direct
{
std
::
string
operator
()(
Type
t
,
std
::
string
name
)
const
{
if
(
t
.
is_float
())
{
return
name
;
}
else
{
return
""
;
}
return
name
;
}
};
//
Directly call pure extern function for floats
.
//
Call pure extern function
.
template
<
typename
T
>
inline
void
DispatchExtern
(
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
Expr
e
=
args
[
0
];
...
...
src/codegen/intrin_rule_cuda.cc
View file @
af9f69a7
...
...
@@ -36,6 +36,19 @@ struct CUDAFastMath : public CUDAMath {
}
};
struct
CUDAPopcount
{
std
::
string
operator
()(
Type
t
,
std
::
string
name
)
const
{
if
(
t
.
lanes
()
==
1
&&
t
.
is_uint
())
{
switch
(
t
.
bits
())
{
case
32
:
return
"__popc"
;
case
64
:
return
"__popcll"
;
default
:
return
""
;
}
}
return
""
;
}
};
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.exp"
)
.
set_body
(
DispatchExtern
<
CUDAFastMath
>
);
...
...
@@ -51,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt")
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.pow"
)
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.popcount"
)
.
set_body
(
DispatchExtern
<
CUDAPopcount
>
);
}
// namespace intrin
}
// namespace codegen
}
// namespace tvm
src/codegen/intrin_rule_metal.cc
View file @
af9f69a7
...
...
@@ -10,19 +10,22 @@ namespace codegen {
namespace
intrin
{
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.exp"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.log"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.tanh"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.sqrt"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.pow"
)
.
set_body
(
DispatchExtern
<
FloatDirect
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.metal.popcount"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
}
// namespace intrin
}
// namespace codegen
...
...
src/codegen/intrin_rule_opencl.cc
View file @
af9f69a7
...
...
@@ -10,19 +10,22 @@ namespace codegen {
namespace
intrin
{
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.exp"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.log"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.tanh"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.sqrt"
)
.
set_body
(
DispatchExtern
<
Float
Direct
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.pow"
)
.
set_body
(
DispatchExtern
<
FloatDirect
>
);
.
set_body
(
DispatchExtern
<
Direct
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.opencl.popcount"
)
.
set_body
(
DispatchExtern
<
Direct
>
);
}
// namespace intrin
}
// namespace codegen
...
...
tests/python/integration/test_ewise.py
View file @
af9f69a7
...
...
@@ -60,25 +60,40 @@ def test_log_pow_llvm():
b
.
asnumpy
(),
np
.
power
(
np
.
log
(
a
.
asnumpy
()),
2.0
),
rtol
=
1e-5
)
def
test_popcount_llvm
():
# graph
n
=
tvm
.
var
(
'n'
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
,
dtype
=
"uint32"
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
tvm
.
popcount
(
A
(
*
i
)),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
def
test_popcount
():
def
run
(
dtype
):
# graph
n
=
tvm
.
convert
(
1024
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
,
dtype
=
dtype
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
*
i
:
tvm
.
popcount
(
A
(
*
i
)),
name
=
'B'
)
s
=
tvm
.
create_schedule
(
B
.
op
)
# simple schedule
num_thread
=
8
bx
,
tx
=
s
[
B
]
.
split
(
B
.
op
.
axis
[
0
],
factor
=
num_thread
)
if
not
tvm
.
module
.
enabled
(
"llvm"
):
return
f
=
tvm
.
build
(
s
,
[
A
,
B
],
"llvm"
)
ctx
=
tvm
.
cpu
(
0
)
# launch the kernel.
n
=
1024
a
=
tvm
.
nd
.
array
(
np
.
random
.
randint
(
low
=
0
,
high
=
1000
,
size
=
n
,
dtype
=
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
shape
=
n
,
dtype
=
B
.
dtype
),
ctx
)
f
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
list
(
map
(
lambda
x
:
bin
(
x
)
.
count
(
'1'
),
a
.
asnumpy
())),
rtol
=
1e-5
)
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
print
(
"skip because
%
s is not enabled.."
%
device
)
return
ctx
=
tvm
.
context
(
device
,
0
)
if
str
(
ctx
)
.
startswith
(
'gpu'
):
s
[
B
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
B
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
func
=
tvm
.
build
(
s
,
[
A
,
B
],
device
)
# launch the kernel.
n
=
1024
a
=
tvm
.
nd
.
array
(
np
.
random
.
randint
(
low
=
0
,
high
=
1000
,
size
=
n
,
dtype
=
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
zeros
(
shape
=
n
,
dtype
=
B
.
dtype
),
ctx
)
func
(
a
,
b
)
np
.
testing
.
assert_allclose
(
b
.
asnumpy
(),
list
(
map
(
lambda
x
:
bin
(
x
)
.
count
(
'1'
),
a
.
asnumpy
())),
rtol
=
1e-5
)
check_device
(
"llvm"
)
check_device
(
"cuda"
)
check_device
(
"opencl"
)
check_device
(
"metal"
)
run
(
'uint32'
)
run
(
'uint64'
)
def
test_add
():
...
...
@@ -133,5 +148,5 @@ def test_add():
if
__name__
==
"__main__"
:
test_add
()
test_log_pow_llvm
()
test_popcount
_llvm
()
test_popcount
()
test_exp
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment