wenyuanbo / tic / Commits / 979623e5

Commit 979623e5, authored May 12, 2017 by Tianqi Chen; committed via GitHub on May 12, 2017.
[Tutorial] External Tensor Op (#137)
parent 553657eb
Showing 11 changed files with 261 additions and 47 deletions (+261, -47)
docs/api/python/contrib.rst                          +6    -0
docs/api/python/schedule.rst                         +3    -0
python/tvm/api.py                                    +44   -5
python/tvm/build.py                                  +45   -20
python/tvm/schedule.py                               +12   -1
tests/python/unittest/test_pass_loop_partition.py    +23   -2
tests/python/unittest/test_schedule_lstm.py          +1    -1
tests/python/unittest/test_schedule_schedule_ops.py  +2    -2
tutorials/python/extern_op.py                        +109  -0
tutorials/python/reduction.py                        +4    -4
tutorials/python/schedule_primitives.py              +12   -12
docs/api/python/contrib.rst
@@ -16,3 +16,9 @@ tvm.contrib.util
 ~~~~~~~~~~~~~~~~
 .. automodule:: tvm.contrib.util
    :members:

+tvm.contrib.cblas
+~~~~~~~~~~~~~~~~~
+.. automodule:: tvm.contrib.cblas
+   :members:
docs/api/python/schedule.rst
@@ -5,6 +5,9 @@ tvm.schedule
 .. autoclass:: tvm.schedule.IterVar
     :members:

+.. autoclass:: tvm.schedule.Buffer
+    :members:
+
 .. autofunction:: tvm.create_schedule

 .. autoclass:: tvm.schedule.Schedule
python/tvm/api.py

@@ -236,14 +236,13 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
     res = [op.output(i) for i in range(len(update))]
     return res[0] if len(res) == 1 else res


 def extern(shape, inputs, fcompute,
            name="extern", dtype=None):
     """Compute several tensors via an extern function.

     Parameters
     ----------
-    shape: Shape tuple or list of shapes.
+    shape: tuple or list of tuples.
         The shape of the outputs.

     inputs: list of Tensor

@@ -251,6 +250,17 @@ def extern(shape, inputs, fcompute,
     fcompute: lambda function of inputs, outputs-> stmt
         Specifies the IR statement to do the computation.
         See the following note for the function signature of fcompute.

+        .. note::
+             **Parameters**
+
+             - **ins** (list of :any:`Buffer`) - Placeholder for each input
+             - **outs** (list of :any:`Buffer`) - Placeholder for each output
+
+             **Returns**
+
+             - **stmt** (:any:`Stmt`) - The statement that carries out array computation.
+
     name: str, optional
         The name hint of the tensor

@@ -263,9 +273,23 @@ def extern(shape, inputs, fcompute,
     -------
     tensor: Tensor or list of Tensors
         The created tensor, or a tuple of tensors if it contains multiple outputs.

+    Example
+    -------
+    In the code below, C is generated by calling the external PackedFunc
+    `tvm.contrib.cblas.matmul`
+
+    .. code-block:: python
+
+        A = tvm.placeholder((n, l), name='A')
+        B = tvm.placeholder((l, m), name='B')
+        C = tvm.extern((n, m), [A, B],
+                       lambda ins, outs: tvm.call_packed(
+                           "tvm.contrib.cblas.matmul",
+                           ins[0], ins[1], outs[0], 0, 0), name="C")
     """
-    if isinstance(shape[0], _expr.Expr):
-        shape = [shape]
+    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
+    shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
     input_placeholders = []
     output_placeholders = []
     types = set()
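The note above defines the fcompute contract: ins and outs are Buffer placeholders, and the shape normalization accepts either one shape tuple or a list of tuples, one per output, matching the "Tensor or list of Tensors" return. A minimal sketch of the multi-output form, not taken from the commit itself; "my_split" is a hypothetical PackedFunc assumed to be registered by the user:

import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
# Two output shapes, so tvm.extern should hand back a list of two Tensors.
outs = tvm.extern([(n,), (n,)], [A],
                  lambda ins, outs: tvm.call_packed(
                      "my_split",           # hypothetical registered PackedFunc
                      ins[0], outs[0], outs[1]),
                  name="split")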
@@ -305,6 +329,8 @@ def decl_buffer(shape, dtype=None,
     Normally buffer is created automatically during lower and build.
     This is only needed if the user wants to specify their own buffer layout.

+    See the note below for a detailed discussion on usage of buffer.
+
     Parameters
     ----------
     shape : tuple of Expr

@@ -332,8 +358,21 @@ def decl_buffer(shape, dtype=None,
     -------
     buffer : Buffer
         The created buffer

+    Note
+    ----
+    The Buffer data structure reflects the DLTensor structure in dlpack.
+    While the DLTensor data structure is very general, it is usually helpful
+    to create a function that only handles a specific case of data structure
+    and let the compiled function benefit from it.
+
+    If strides and byte_offset are passed as None when constructing the
+    function, then the function will be specialized for DLTensors that are
+    compact and aligned.
+    If fully generic symbolic arrays are passed for the strides, then the
+    resulting function becomes fully generic.
     """
-    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if data is None:
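To illustrate the note on buffer specialization, here is a small sketch that is not part of the commit: it declares an explicit buffer for an input and binds it through the binds argument, assuming tvm.build forwards binds to lower as in this codebase. Leaving strides and byte_offset as None keeps the compact, aligned specialization described above.

import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)

# Declare a buffer for A explicitly; strides/byte_offset stay None, so the
# compiled function is specialized to compact, aligned DLTensors.
Ab = tvm.decl_buffer(A.shape, dtype=A.dtype, name='Ab')
f = tvm.build(s, [A, B], "llvm", binds={A: Ab})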
python/tvm/build.py

@@ -13,12 +13,47 @@ from . import collections
 from . import module
 from . import codegen


+def get_binds(args, binds=None):
+    """Internal function to get binds and arg_list given arguments.
+
+    Parameters
+    ----------
+    args : list of Buffer or Tensor or Var
+        The argument lists to the function.
+
+    binds : dict, optional
+        Dictionary that maps the binding of symbolic buffer to Tensor.
+        By default, a new buffer is created for each tensor in the argument.
+
+    Returns
+    -------
+    binds: dict
+        The bind specification
+
+    arg_list: list
+        The list of symbolic buffers of arguments.
+    """
+    binds = {} if binds is None else binds.copy()
+    arg_list = []
+    for x in args:
+        if isinstance(x, tensor.Tensor):
+            buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
+            assert x not in binds
+            binds[x] = buf
+            arg_list.append(buf)
+        elif isinstance(x, schedule.Buffer):
+            arg_list.append(x)
+        elif isinstance(x, expr.Var):
+            arg_list.append(x)
+        else:
+            raise ValueError("args must be Tensor, Buffer or Var")
+    return binds, arg_list
+
+
 def lower(sch,
           args,
           name="default_function",
           binds=None,
-          with_api_wrapper=True,
+          simple_mode=False,
           max_auto_unroll_step=0):
     """Lowering step before build into target.

@@ -37,8 +72,9 @@ def lower(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.

-    with_api_wrapper : bool, optional
-        Whether to add the API wrapper during lowering.
+    simple_mode : bool, optional
+        Whether to only output a simple and compact statement; this will skip
+        LoopPartition, API wrapper generation and unrolling.

     max_auto_unroll_step: int, optional
         Maximum step to perform automatic unrolling

@@ -49,33 +85,22 @@ def lower(sch,
        The result function, if with_api_wrapper=False
        then the Stmt before make api is returned.
    """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
-            assert x not in binds
-            binds[x] = buf
-            arg_list.append(buf)
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, expr.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
+    binds, arg_list = get_binds(args, binds)
     # normalize schedule first
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    stmt = ir_pass.LoopPartition(stmt)
+    if not simple_mode:
+        stmt = ir_pass.LoopPartition(stmt)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.CanonicalSimplify(stmt)
     stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.StorageRewrite(stmt)
-    stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
+    if not simple_mode:
+        stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
-    if not with_api_wrapper:
-        return stmt
+    if simple_mode:
+        return stmt
     return ir_pass.MakeAPI(stmt, name, arg_list, 0)
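A quick sketch, not taken from the commit, of how the new simple_mode flag changes what lower returns: with simple_mode=True the pipeline above skips LoopPartition, unrolling and MakeAPI and hands back the bare Stmt, which is what the updated tutorials print; the default path still produces the wrapped function.

import tvm

n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] * 2, name='B')
s = tvm.create_schedule(B.op)

stmt = tvm.lower(s, [A, B], simple_mode=True)   # bare Stmt in a readable, C-like form
func = tvm.lower(s, [A, B])                     # default: result after MakeAPI
print(stmt)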
python/tvm/schedule.py

@@ -10,7 +10,18 @@ from ._ffi.function import _init_api
 @register_node
 class Buffer(NodeBase):
-    """Represent a symbolic buffer in TVM."""
+    """Symbolic data buffer in TVM.
+
+    Buffer provides a way to represent the data layout
+    specialization of a data structure in TVM.
+
+    Do not construct directly; use :any:`decl_buffer` instead.
+    See the documentation of :any:`decl_buffer` for more details.
+
+    See Also
+    --------
+    decl_buffer : Declare a buffer
+    """
     pass


 @register_node
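As the new docstring says, Buffer objects are obtained from tvm.decl_buffer rather than constructed directly; a minimal sketch, not part of the commit, assuming the decl_buffer API shown earlier:

import tvm

buf = tvm.decl_buffer((1024,), dtype="float32", name="buf")
assert isinstance(buf, tvm.schedule.Buffer)  # declared buffers are schedule.Buffer nodes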
tests/python/unittest/test_pass_loop_partition.py

@@ -5,6 +5,27 @@ def collect_visit(stmt, f):
     tvm.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
     return ret

+def lower(sch, args):
+    binds = {}
+    arg_list = []
+    for x in args:
+        if isinstance(x, tvm.tensor.Tensor):
+            buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
+            assert x not in binds
+            binds[x] = buf
+            arg_list.append(buf)
+        else:
+            raise ValueError("args must be Tensor, Buffer or Var")
+    sch = sch.normalize()
+    bounds = tvm.schedule.InferBound(sch)
+    stmt = tvm.schedule.ScheduleOps(sch, bounds)
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    stmt = tvm.ir_pass.VectorizeLoop(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    return stmt
+
 def test_basic():
     n = tvm.var('n')
     A = tvm.placeholder((n, ), name='A')

@@ -92,7 +113,7 @@ def test_vectorize():
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
     s[C].vectorize(x)
-    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    stmt = lower(s, [A, B])
     body = stmt.body.body.body.body.body
     assert(x.var.name not in str(body.condition))
     assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))

@@ -123,7 +144,7 @@ def test_thread_axis2():
     _, x = s[C].split(x, factor=m)
     s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
-    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    stmt = lower(s, [A, B])
     for_body = stmt.body.body.body.body.body.first
     assert('threadIdx' not in str(for_body.extent))
tests/python/unittest/test_schedule_lstm.py

@@ -59,7 +59,7 @@ def test_lstm_cell_inline():
     s[forget_gate].compute_inline()
     s[out_gate].compute_inline()
     # verify we can lower correctly
-    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c], with_api_wrapper=False)
+    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])

 if __name__ == "__main__":
     test_lstm_cell_inline()
tests/python/unittest/test_schedule_schedule_ops.py

@@ -109,7 +109,7 @@ def test_scan_inline1():
                    [s_state1, s_state2])
     s = tvm.create_schedule(res1.op)
     s[s_x1].compute_inline()
-    stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
+    stmt = tvm.lower(s, [x, res1, res2])

 def test_scan_inline2():
     m = tvm.var("m")

@@ -131,7 +131,7 @@ def test_scan_inline2():
     s[s_xx].compute_inline()
     s[s_x1].compute_inline()
     s[s_x2].compute_inline()
-    stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
+    stmt = tvm.lower(s, [x, res1, res2])

 def test_schedule_cache():
tutorials/python/extern_op.py  (new file, 0 → 100644)

"""
External Tensor Functions
=========================
**Author**: `Tianqi Chen <https://tqchen.github.io>`_

While TVM supports transparent code generation, sometimes
it is also helpful to incorporate manually written code into
the pipeline. For example, we might want to use cuDNN for
some of the convolution kernels and define the rest of the stages.

TVM supports these black box function calls natively.
Specifically, TVM supports all the tensor functions that are DLPack compatible,
which means we can call any function that takes POD types (pointer, int, float)
or pointers to DLTensor as arguments.
"""
from __future__ import absolute_import, print_function

import tvm
import numpy as np
from tvm.contrib import cblas

######################################################################
# Use Extern Tensor Function
# --------------------------
# In the example below, we use :any:`tvm.extern` to add an extern
# array function call. In the extern call, we declare the shape
# of the output tensors. In the second argument we provide the list of inputs.
#
# The user needs to provide a function describing how to compute the result.
# The compute function takes a list of symbolic placeholders for the inputs,
# a list of symbolic placeholders for the outputs, and returns the executing statement.
#
# In this case we simply call a registered TVM function, which invokes a CBLAS call.
# TVM does not control the internals of the extern array function and treats it as a black box.
# We can further mix schedulable TVM calls that add a bias term to the result.
#
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = tvm.extern((n, m), [A, B],
               lambda ins, outs: tvm.call_packed(
                   "tvm.contrib.cblas.matmul",
                   ins[0], ins[1], outs[0], False, False), name="C")
D = tvm.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
s = tvm.create_schedule(D.op)

######################################################################
# Verify the Result
# -----------------
# We can verify that the result matches what we expected.
#
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
    d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 10)

######################################################################
# Extern Contrib Wrappers
# -----------------------
# TVM also provides extern contrib wrappers for useful extern calls;
# the following lines are equivalent to the previous example.
#
from tvm.contrib import cblas
C = cblas.matmul(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i, j] + bias, name="D")
s = tvm.create_schedule(D.op)

######################################################################
# Hook Python Function as Extern
# ------------------------------
# Since we can call into any PackedFunc in TVM, we can use the extern
# function to call back into Python.
#
# The following example registers a Python function into the TVM runtime system
# and uses it to complete one stage of the computation.
# This makes TVM much more flexible. For example, we can insert front-end
# callbacks to inspect the intermediate results or mix customized code
# with TVM.
#
@tvm.register_func("tvm.contrib.my_tvm_addone")
def my_tvm_addone(x, y):
    print("my_tvm_addone signatures: %s, %s" % (type(x), type(y)))
    tvm.nd.array(x.asnumpy() + 1).copyto(y)

A = tvm.placeholder((n,), name='A')
B = tvm.extern(A.shape, [A], lambda ins, outs: tvm.call_packed(
    "tvm.contrib.my_tvm_addone", ins[0], outs[0]), name="C")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)

######################################################################
# Summary
# -------
# - TVM calls extern tensor functions via :any:`tvm.extern`
# - Use contrib wrappers for short sugars of extern tensor calls.
# - We can hook front-end functions as extern tensor callbacks.
#
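One small follow-up sketch, not part of the tutorial: after @tvm.register_func the callback is an ordinary PackedFunc in the global registry, so it can also be fetched and invoked directly from Python, assuming tvm.get_global_func is available in this TVM version.

# Fetch the registered callback and invoke it outside of any schedule.
fadd = tvm.get_global_func("tvm.contrib.my_tvm_addone")
x = tvm.nd.array(np.zeros(10, dtype="float32"))
y = tvm.nd.array(np.zeros(10, dtype="float32"))
fadd(x, y)
np.testing.assert_allclose(y.asnumpy(), x.asnumpy() + 1)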
tutorials/python/reduction.py

@@ -50,7 +50,7 @@ B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
 # Before doing anything, let us print out the IR code of default schedule.
 #
 s = tvm.create_schedule(B.op)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # You can find that the IR code is quite like the C code.

@@ -61,13 +61,13 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
 #
 ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
 xo, xi = s[B].split(B.op.axis[0], factor=32)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # If we are building a GPU kernel, we can bind the rows of B to GPU threads.
 s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
 s[B.op].bind(xi, tvm.thread_axis("threadIdx.x"))
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # Reduction Factoring and Parallelization

@@ -84,7 +84,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
 s = tvm.create_schedule(B.op)
 ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
 BF = s.rfactor(B, ki)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # The scheduled operator of B also get rewritten to be sum over
tutorials/python/schedule_primitives.py

@@ -39,10 +39,10 @@ C = tvm.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')
 s = tvm.create_schedule([C.op])
 # lower will transform the computation from definition to the real
-# callable function. With argument `with_api_wrapper=False`, it will
+# callable function. With argument `simple_mode=True`, it will
 # return you a readable C like statement, we use it here to print the
 # schedule result.
-print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
+print(tvm.lower(s, [A, B, C], simple_mode=True))

 ######################################################################
 # One schedule is composed by multiple stages, and one

@@ -59,7 +59,7 @@ B = tvm.compute((m,), lambda i: A[i]*2, name='B')
 s = tvm.create_schedule(B.op)
 xo, xi = s[B].split(B.op.axis[0], factor=32)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # You can also split a axis by :code:`nparts`, which splits the axis

@@ -69,7 +69,7 @@ B = tvm.compute((m,), lambda i: A[i], name='B')
 s = tvm.create_schedule(B.op)
 bx, tx = s[B].split(B.op.axis[0], nparts=32)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # tile

@@ -81,7 +81,7 @@ B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')
 s = tvm.create_schedule(B.op)
 xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # fuse

@@ -95,7 +95,7 @@ s = tvm.create_schedule(B.op)
 xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
 # then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
 fused = s[B].fuse(yi, xi)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # reorder

@@ -109,7 +109,7 @@ s = tvm.create_schedule(B.op)
 xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
 # then reorder the axises: (i.inner, j.outer, i.outer, j.inner)
 s[B].reorder(xi, yo, xo, yi)
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # bind

@@ -123,7 +123,7 @@ s = tvm.create_schedule(B.op)
 bx, tx = s[B].split(B.op.axis[0], factor=64)
 s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
 s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
-print(tvm.lower(s, [A, B], with_api_wrapper=False))
+print(tvm.lower(s, [A, B], simple_mode=True))

 ######################################################################
 # compute_at

@@ -135,7 +135,7 @@ B = tvm.compute((m,), lambda i: A[i]+1, name='B')
 C = tvm.compute((m,), lambda i: B[i]*2, name='C')
 s = tvm.create_schedule(C.op)
-print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
+print(tvm.lower(s, [A, B, C], simple_mode=True))

 ######################################################################
 # :code:`compute_at` can move computation of `B` into the first axis

@@ -146,7 +146,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
 s = tvm.create_schedule(C.op)
 s[B].compute_at(s[C], C.op.axis[0])
-print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
+print(tvm.lower(s, [A, B, C], simple_mode=True))

 ######################################################################
 # compute_inline

@@ -160,7 +160,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
 s = tvm.create_schedule(C.op)
 s[B].compute_inline()
-print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
+print(tvm.lower(s, [A, B, C], simple_mode=True))

 ######################################################################
 # compute_root

@@ -173,7 +173,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
 s = tvm.create_schedule(C.op)
 s[B].compute_at(s[C], C.op.axis[0])
 s[B].compute_root()
-print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
+print(tvm.lower(s, [A, B, C], simple_mode=True))

 ######################################################################
 # Summary