Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
bf8a5c07
Commit
bf8a5c07
authored
May 08, 2017
by
Tianqi Chen
Committed by
GitHub
May 08, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[SCHEDULE] Add store_predicate (#131)
parent
2112a1f9
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
271 additions
and
14 deletions
+271
-14
Makefile
+1
-1
include/tvm/schedule.h
+19
-0
make/config.mk
+3
-3
python/tvm/build.py
+1
-1
python/tvm/expr.py
+16
-1
python/tvm/schedule.py
+13
-0
src/api/api_lang.cc
+6
-0
src/codegen/codegen_cuda.h
+1
-1
src/codegen/intrin_rule_cuda.cc
+14
-3
src/op/compute_op.cc
+8
-2
src/schedule/schedule_lang.cc
+6
-0
tests/python/integration/test_reduce.py
+3
-1
tests/python/perf/lstm.py
+178
-0
tests/python/perf/rnn_matexp.py
+2
-1
No files found.
Makefile
View file @
bf8a5c07
...
...
@@ -76,7 +76,7 @@ else
endif
# llvm configuration
if
eq
($(USE_LLVM),
1)
if
def
LLVM_CONFIG
LLVM_VERSION
=
$(
shell
$(LLVM_CONFIG)
--version
| cut
-b
1,3
)
LLVM_INCLUDE
=
$
(
filter
-I
%,
$(
shell
$(LLVM_CONFIG)
--cxxflags
)
)
LDFLAGS
+=
$(
shell
$(LLVM_CONFIG)
--ldflags
--libs
--system-libs
)
...
...
include/tvm/schedule.h
View file @
bf8a5c07
...
...
@@ -82,10 +82,22 @@ class Stage : public NodeRef {
*/
Stage
&
bind
(
IterVar
ivar
,
IterVar
thread_ivar
);
/*!
* \brief Set predicate under which store to the array can be performed.
* Use this when there are duplicated threads doing the same store and we only
* need one of them to do the store.
*
* \note This is a dangerous scheduling primitive that can change behavior of program.
* Only do when we are certain that thare are duplicated store.
* \param predicate The condition to be checked.
* \return reference to self.
*/
Stage
&
set_store_predicate
(
Expr
predicate
);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
* \param threads The threads to be launched around the scope.
* \note Each thread can only appear in one env_threads.
* This is a beta feature.
* \return reference to self.
*/
Stage
&
env_threads
(
Array
<
IterVar
>
threads
);
...
...
@@ -341,8 +353,15 @@ class StageNode : public Node {
/*!
* \brief Specify threads to be launched at the stage.
* This is only valid for composite ops such as Scan.
* \note Experimental primitive: used for thread persistence.
*/
Array
<
IterVar
>
env_threads
;
/*!
* \brief The predicate under which store can happen
* Use this when there can be duplicated threads doing the same store.
* \note Experimental primitive: used by cross thread-reduction.
*/
Expr
store_predicate
;
/*! \brief The relation bwteen of IterVars */
Array
<
IterVarRelation
>
relations
;
/*! \brief additional attributes about iter var. */
...
...
make/config.mk
View file @
bf8a5c07
...
...
@@ -39,9 +39,9 @@ USE_METAL = 0
# whether build with LLVM support
# Requires LLVM version >= 4.0
# Set LLVM_CONFIG to your version
#
LLVM_CONFIG = llvm-config-4.0
USE_LLVM = 0
# Set LLVM_CONFIG to your version
, uncomment to build with llvm support
#
# LLVM_CONFIG = llvm-config
#---------------------------------------------
# Contrib optional libraries.
...
...
python/tvm/build.py
View file @
bf8a5c07
...
...
@@ -85,7 +85,7 @@ def build(sch,
target_host
=
None
,
name
=
"default_function"
,
binds
=
None
,
max_auto_unroll_step
=
8
,
max_auto_unroll_step
=
0
,
detect_global_barrier
=
True
):
"""Build a function with arguments as signiture.
...
...
python/tvm/expr.py
View file @
bf8a5c07
...
...
@@ -63,7 +63,7 @@ class ExprOp(object):
return
_make
.
LE
(
self
,
other
)
def
__eq__
(
self
,
other
):
return
_make
.
EQ
(
self
,
other
)
return
self
.
equal
(
other
)
def
__ne__
(
self
,
other
):
return
_make
.
NE
(
self
,
other
)
...
...
@@ -74,6 +74,21 @@ class ExprOp(object):
def
__ge__
(
self
,
other
):
return
_make
.
GE
(
self
,
other
)
def
equal
(
self
,
other
):
"""Build an equal check expression with other expr.
Parameters
----------
other : Expr
The other expression
Returns
-------
ret : Expr
The equality expression.
"""
return
_make
.
EQ
(
self
,
other
)
class
Expr
(
NodeBase
,
ExprOp
):
"""Base class of all tvm Expressions"""
...
...
python/tvm/schedule.py
View file @
bf8a5c07
...
...
@@ -276,6 +276,19 @@ class Stage(NodeBase):
threads
=
[
threads
]
_api_internal
.
_StageEnvThreads
(
self
,
threads
)
def
set_store_predicate
(
self
,
predicate
):
"""Set predicate under which store to the array can be performed.
Use this when there are duplicated threads doing the same store and we only
need one of them to do the store.
Parameters
----------
predicate : Expr
The guard condition fo store.
"""
_api_internal
.
_StageSetStorePredicate
(
self
,
predicate
)
def
compute_at
(
self
,
parent
,
scope
):
"""Attach the stage at parent's scope
...
...
src/api/api_lang.cc
View file @
bf8a5c07
...
...
@@ -307,6 +307,12 @@ TVM_REGISTER_API("_StageEnvThreads")
.
env_threads
(
args
[
1
]);
});
TVM_REGISTER_API
(
"_StageSetStorePredicate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
.
set_store_predicate
(
args
[
1
]);
});
TVM_REGISTER_API
(
"_StageUnroll"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
args
[
0
].
operator
Stage
()
...
...
src/codegen/codegen_cuda.h
View file @
bf8a5c07
...
...
@@ -38,7 +38,7 @@ class CodeGenCUDA final : public CodeGenC {
private
:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int
max_auto_unroll_
{
64
};
int
max_auto_unroll_
{
256
};
// Whether global barrier is needed.
bool
need_global_barrier_
{
false
};
// Global barrier state
...
...
src/codegen/intrin_rule_cuda.cc
View file @
bf8a5c07
...
...
@@ -9,13 +9,13 @@ namespace tvm {
namespace
codegen
{
namespace
intrin
{
// Add float suffix to the intrinsics, CUDA fast math.
struct
CUDA
Fast
Math
{
struct
CUDAMath
{
std
::
string
operator
()(
Type
t
,
std
::
string
name
)
const
{
if
(
t
.
lanes
()
==
1
)
{
if
(
t
.
is_float
())
{
switch
(
t
.
bits
())
{
case
64
:
return
name
;
case
32
:
return
"__"
+
name
+
'f'
;
case
32
:
return
name
+
'f'
;
case
16
:
return
'h'
+
name
;
default
:
return
""
;
}
...
...
@@ -25,6 +25,17 @@ struct CUDAFastMath {
}
};
struct
CUDAFastMath
:
public
CUDAMath
{
std
::
string
operator
()(
Type
t
,
std
::
string
name
)
const
{
if
(
t
.
lanes
()
==
1
&&
t
.
is_float
()
&&
t
.
bits
()
==
32
)
{
return
"__"
+
name
+
'f'
;
}
else
{
return
CUDAMath
::
operator
()(
t
,
name
);
}
return
""
;
}
};
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.exp"
)
.
set_body
(
DispatchExtern
<
CUDAFastMath
>
);
...
...
@@ -32,7 +43,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.
set_body
(
DispatchExtern
<
CUDAFastMath
>
);
TVM_REGISTER_GLOBAL
(
"tvm.intrin.rule.cuda.tanh"
)
.
set_body
(
DispatchExtern
<
CUDA
Fast
Math
>
);
.
set_body
(
DispatchExtern
<
CUDAMath
>
);
}
// namespace intrin
}
// namespace codegen
...
...
src/op/compute_op.cc
View file @
bf8a5c07
...
...
@@ -242,7 +242,6 @@ Stmt MakeCrossThreadReduction(
freduce_args
.
push_back
(
reduce
->
source
);
freduce_args
.
push_back
(
cond
);
std
::
vector
<
Expr
>
thread_head_check
;
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
if
(
iv
->
iter_type
==
kCommReduce
)
{
auto
it
=
stage
->
iter_var_attrs
.
find
(
iv
);
...
...
@@ -250,10 +249,14 @@ Stmt MakeCrossThreadReduction(
(
*
it
).
second
->
bind_thread
.
defined
())
{
IterVar
tv
=
(
*
it
).
second
->
bind_thread
;
freduce_args
.
push_back
(
tv
->
var
);
thread_head_check
.
push_back
(
tv
->
var
==
0
);
}
}
}
// Checks for the thread.
std
::
vector
<
Expr
>
thread_head_check
;
if
(
stage
->
store_predicate
.
defined
())
{
thread_head_check
.
emplace_back
(
stage
->
store_predicate
);
}
Type
t
=
reduce
->
type
;
Expr
pred
=
const_true
(
t
.
lanes
());
Stmt
reduce_body
=
Store
::
make
(
res_handle
,
...
...
@@ -311,6 +314,9 @@ Stmt ComputeOpNode::BuildProvide(
nest
.
push_back
(
op
::
MakeIfNest
(
op
::
MakeBoundCheck
(
stage
,
dom_map
,
false
,
std
::
unordered_set
<
IterVar
>
(),
value_map
)));
if
(
stage
->
store_predicate
.
defined
())
{
nest
.
emplace_back
(
op
::
MakeIfNest
({
stage
->
store_predicate
}));
}
provide
=
Substitute
(
provide
,
value_map
);
if
(
init
.
defined
())
{
...
...
src/schedule/schedule_lang.cc
View file @
bf8a5c07
...
...
@@ -200,6 +200,12 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
return
*
this
;
}
Stage
&
Stage
::
set_store_predicate
(
Expr
predicate
)
{
StageNode
*
self
=
operator
->
();
self
->
store_predicate
=
predicate
;
return
*
this
;
}
Stage
&
Stage
::
split
(
IterVar
parent
,
Expr
factor
,
IterVar
*
p_outer
,
IterVar
*
p_inner
)
{
// NOLINT(*)
Split
(
operator
->
(),
parent
,
factor
,
Expr
(),
p_outer
,
p_inner
);
...
...
tests/python/integration/test_reduce.py
View file @
bf8a5c07
...
...
@@ -98,8 +98,10 @@ def test_rfactor_threads():
s
[
B
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
B
]
.
bind
(
ty
,
tvm
.
thread_axis
(
"threadIdx.y"
))
tx
=
s
[
B
]
.
op
.
reduce_axis
[
0
]
s
[
B
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
thread_x
=
tvm
.
thread_axis
(
"threadIdx.x"
)
s
[
B
]
.
bind
(
tx
,
thread_x
)
s
[
BF
]
.
compute_at
(
s
[
B
],
tx
)
s
[
B
]
.
set_store_predicate
(
thread_x
.
var
.
equal
(
0
))
# one line to build the function.
def
check_target
(
device
,
host
=
"stackvm"
):
...
...
tests/python/perf/lstm.py
0 → 100644
View file @
bf8a5c07
"""LSTM Example, still work in progress.."""
import
tvm
import
time
import
os
import
argparse
from
tvm.contrib
import
nvcc_compiler
import
numpy
as
np
# Quick knobs
TASK
=
"lstm"
USE_MANUAL_CODE
=
False
PERSIST_KERNEL
=
True
DETECT_GLOBAL_BARRIER
=
PERSIST_KERNEL
SKIP_CHECK
=
False
UNROLL_WLOAD
=
True
@tvm.register_func
def
tvm_callback_cuda_compile
(
code
):
"""Use nvcc compiler for better perf."""
ptx
=
nvcc_compiler
.
compile_source
(
code
,
target
=
"ptx"
,
options
=
[
"-arch=sm_52"
])
return
ptx
def
write_code
(
code
,
fname
):
with
open
(
fname
,
"w"
)
as
f
:
f
.
write
(
code
)
@tvm.register_func
def
tvm_callback_cuda_postproc
(
code
):
if
not
os
.
path
.
exists
(
"perf"
):
os
.
mkdir
(
"perf"
)
write_code
(
code
,
"perf/
%
s_generated.cu"
%
TASK
)
if
USE_MANUAL_CODE
:
code
=
open
(
"perf/
%
s_manual.cu"
%
TASK
)
.
read
()
return
code
def
lstm
():
if
not
PERSIST_KERNEL
:
raise
ValueError
(
"Non persist LSTM not yet supported"
)
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
num_thread_y
=
8
num_thread_x
=
16
*
3
/
2
num_sm
=
24
n_num_step
=
128
num_step
=
tvm
.
var
(
'num_step'
)
num_hidden
=
1152
/
2
batch_size
=
1
# Global transition matrix
# Input hidden channel can be pre-caculated by a gemm
Xi2h
=
tvm
.
placeholder
((
num_step
,
batch_size
,
4
,
num_hidden
),
name
=
"Xi2h"
)
# Only handle hidden transition, saves space.
Wh2h
=
tvm
.
placeholder
((
4
,
num_hidden
,
num_hidden
),
name
=
"Wh2h"
)
# h: output hidden state, c: cell state.
s_state_h
=
tvm
.
placeholder
((
num_step
,
batch_size
,
num_hidden
))
s_state_c
=
tvm
.
placeholder
((
num_step
,
batch_size
,
num_hidden
))
s_init_c
=
tvm
.
compute
((
1
,
batch_size
,
num_hidden
),
lambda
*
i
:
0.0
,
name
=
"init_c"
)
s_init_h
=
tvm
.
compute
((
1
,
batch_size
,
num_hidden
),
lambda
*
i
:
0.0
,
name
=
"init_h"
)
# LSTM transition
k
=
tvm
.
reduce_axis
((
0
,
num_hidden
),
name
=
"ki2h"
)
s_h2h
=
tvm
.
compute
(
(
num_step
,
batch_size
,
4
,
num_hidden
),
lambda
t
,
i
,
x
,
j
:
tvm
.
sum
(
s_state_h
[
t
-
1
,
i
,
k
]
*
Wh2h
[
x
,
j
,
k
],
axis
=
k
),
name
=
"s_h2h"
)
# Gate rules
gates
=
tvm
.
compute
(
Xi2h
.
shape
,
lambda
*
i
:
Xi2h
(
*
i
)
+
s_h2h
(
*
i
),
name
=
"gates"
)
gshape
=
(
num_step
,
batch_size
,
num_hidden
)
in_gate
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
tvm
.
sigmoid
(
gates
[
t
,
i
,
0
,
j
]),
name
=
"in_gate"
)
in_transform
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
tvm
.
tanh
(
gates
[
t
,
i
,
1
,
j
]),
name
=
"in_transform"
)
forget_gate
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
tvm
.
sigmoid
(
gates
[
t
,
i
,
2
,
j
]),
name
=
"forget_gate"
)
out_gate
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
tvm
.
sigmoid
(
gates
[
t
,
i
,
3
,
j
]),
name
=
"out_gate"
)
next_c
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
forget_gate
[
t
,
i
,
j
]
*
s_state_c
[
t
-
1
,
i
,
j
]
+
in_gate
[
t
,
i
,
j
]
*
in_transform
[
t
,
i
,
j
],
name
=
"next_c"
)
next_h
=
tvm
.
compute
(
gshape
,
lambda
t
,
i
,
j
:
out_gate
[
t
,
i
,
j
]
*
tvm
.
tanh
(
next_c
[
t
,
i
,
j
]),
name
=
"next_h"
)
update_c
=
tvm
.
compute
(
gshape
,
lambda
*
i
:
next_c
(
*
i
),
name
=
"update_c"
)
update_h
=
tvm
.
compute
(
gshape
,
lambda
*
i
:
next_h
(
*
i
),
name
=
"update_h"
)
# schedule
scan_h
,
scan_c
=
tvm
.
scan
(
[
s_init_h
,
s_init_c
],
[
update_h
,
update_c
],
[
s_state_h
,
s_state_c
],
inputs
=
[
Xi2h
],
name
=
"lstm_scan"
)
# schedule
s
=
tvm
.
create_schedule
(
scan_h
.
op
)
# Inline gate computations
s
[
gates
]
.
compute_inline
()
s
[
in_gate
]
.
compute_inline
()
s
[
in_transform
]
.
compute_inline
()
s
[
forget_gate
]
.
compute_inline
()
s
[
out_gate
]
.
compute_inline
()
block_x
=
tvm
.
thread_axis
((
0
,
num_sm
),
"blockIdx.x"
)
thread_x
=
tvm
.
thread_axis
((
0
,
num_thread_x
),
"threadIdx.x"
)
thread_y
=
tvm
.
thread_axis
((
0
,
num_thread_y
),
"threadIdx.y"
)
s_state_h_S
=
s
.
cache_read
(
s_state_h
,
"shared"
,
[
s_h2h
])
s_state_c_S
=
s
.
cache_read
(
s_state_c
,
"shared"
,
[
next_c
])
Wh2hL
=
s
.
cache_read
(
Wh2h
,
"local"
,
[
s_h2h
])
ko
,
ki
=
s
[
s_h2h
]
.
split
(
s
[
s_h2h
]
.
op
.
reduce_axis
[
0
],
nparts
=
num_thread_y
)
s_h2h_rf
=
s
.
rfactor
(
s_h2h
,
ko
)
s
[
s_h2h
]
.
bind
(
s
[
s_h2h
]
.
op
.
reduce_axis
[
0
],
thread_y
)
s
[
s_h2h_rf
]
.
compute_at
(
s
[
s_h2h
],
s
[
s_h2h
]
.
op
.
reduce_axis
[
0
])
if
PERSIST_KERNEL
:
s
[
scan_h
.
op
]
.
env_threads
([
block_x
,
thread_y
,
thread_x
])
s
[
Wh2hL
]
.
compute_at
(
s
[
scan_h
.
op
],
thread_x
)
else
:
s
[
Wh2hL
]
.
compute_at
(
s
[
s_h2h
],
s
[
s_h2h
]
.
op
.
axis
[
3
])
if
UNROLL_WLOAD
:
s
[
Wh2hL
]
.
unroll
(
Wh2hL
.
op
.
axis
[
0
])
s
[
Wh2hL
]
.
unroll
(
Wh2hL
.
op
.
axis
[
2
])
s
[
s_state_h_S
]
.
compute_at
(
s
[
s_h2h_rf
],
s
[
s_h2h_rf
]
.
op
.
axis
[
3
])
s
[
s_state_c_S
]
.
compute_at
(
s
[
scan_h
.
op
],
s
[
scan_h
]
.
op
.
scan_axis
)
for
ss
in
[
s_state_h_S
]:
xo
,
xi
=
s
[
ss
]
.
split
(
ss
.
op
.
axis
[
2
],
factor
=
num_thread_x
*
num_thread_y
)
ty
,
xi
=
s
[
ss
]
.
split
(
xi
,
nparts
=
num_thread_y
)
tx
,
xi
=
s
[
ss
]
.
split
(
xi
,
nparts
=
num_thread_x
)
s
[
ss
]
.
bind
(
ty
,
thread_y
)
s
[
ss
]
.
bind
(
tx
,
thread_x
)
for
init
in
[
s_init_c
,
s_init_h
]:
bx
,
xi
=
s
[
init
]
.
split
(
init
.
op
.
axis
[
2
],
nparts
=
num_sm
)
tx
,
xi
=
s
[
init
]
.
split
(
xi
,
nparts
=
num_thread_x
)
s
[
init
]
.
bind
(
bx
,
block_x
)
s
[
init
]
.
bind
(
tx
,
thread_x
)
s
[
next_c
]
.
set_store_predicate
(
thread_y
.
equal
(
0
))
s
[
next_h
]
.
set_store_predicate
(
thread_y
.
equal
(
0
))
for
update
in
[
update_c
,
update_h
]:
bx
,
xi
=
s
[
update
]
.
split
(
s
[
update
]
.
op
.
axis
[
2
],
nparts
=
num_sm
)
tx
,
xi
=
s
[
update
]
.
split
(
xi
,
nparts
=
num_thread_x
)
s
[
update
]
.
bind
(
bx
,
block_x
)
s
[
update
]
.
bind
(
tx
,
thread_x
)
s
[
update
]
.
set_store_predicate
(
thread_y
.
equal
(
0
))
# verify we can lower correctly
def
check_device
(
target
):
num_step
=
n_num_step
flstm
=
tvm
.
build
(
s
,
[
Xi2h
,
Wh2h
,
scan_h
,
scan_c
],
target
,
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
)
ctx
=
tvm
.
gpu
(
0
)
if
target
==
"cuda"
else
tvm
.
cl
(
0
)
# launch the kernel.
scan_h_np
=
np
.
zeros
(
(
num_step
,
batch_size
,
num_hidden
))
.
astype
(
"float32"
)
scan_c_np
=
np
.
zeros
(
(
num_step
,
batch_size
,
num_hidden
))
.
astype
(
"float32"
)
Xi2h_np
=
np
.
random
.
normal
(
size
=
(
num_step
,
batch_size
,
4
,
num_hidden
))
.
astype
(
"float32"
)
Wh2h_np
=
np
.
random
.
normal
(
size
=
(
4
,
num_hidden
,
num_hidden
))
.
astype
(
"float32"
)
scan_h_a
=
tvm
.
nd
.
array
(
scan_h_np
,
ctx
)
scan_c_a
=
tvm
.
nd
.
array
(
scan_c_np
,
ctx
)
Xi2h_a
=
tvm
.
nd
.
array
(
Xi2h_np
,
ctx
)
Wh2h_a
=
tvm
.
nd
.
array
(
Wh2h_np
,
ctx
)
flstm
(
Xi2h_a
,
Wh2h_a
,
scan_h_a
,
scan_c_a
)
ctx
.
sync
()
# measure time cost of second step.
tstart
=
time
.
time
()
flstm
(
Xi2h_a
,
Wh2h_a
,
scan_h_a
,
scan_c_a
)
ctx
.
sync
()
tgap
=
time
.
time
()
-
tstart
print
(
"Time cost=
%
g"
%
tgap
)
check_device
(
"cuda"
)
if
__name__
==
"__main__"
:
lstm
()
tests/python/perf/rnn_matexp.py
View file @
bf8a5c07
...
...
@@ -90,6 +90,8 @@ def rnn_matexp():
s
[
s_update
]
.
bind
(
tx
,
thread_x
)
s
[
CL
]
.
bind
(
s
[
CL
]
.
op
.
reduce_axis
[
0
],
thread_y
)
s
[
CLF
]
.
compute_at
(
s
[
CL
],
s
[
CL
]
.
op
.
reduce_axis
[
0
])
# Duplicate store predicate.
s
[
CL
]
.
set_store_predicate
(
thread_y
.
equal
(
0
))
if
PERSIST_KERNEL
:
s
[
WhhL
]
.
compute_at
(
s
[
s_scan
],
thread_x
)
...
...
@@ -109,7 +111,6 @@ def rnn_matexp():
s
[
SS
]
.
bind
(
tx
,
thread_x
)
def
check_device
(
target
):
codes
=
[]
f
=
tvm
.
build
(
s
,
[
s_scan
,
Whh
],
target
,
max_auto_unroll_step
=
max_auto_unroll_step
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment