Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1c04f389
Commit
1c04f389
authored
Apr 22, 2017
by
Tianqi Chen
Committed by
GitHub
Apr 22, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[DEV/IR] Python IRBuilder (#102)
parent
d17b10f0
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
447 additions
and
53 deletions
+447
-53
HalideIR
+1
-1
docs/api/python/dev.rst
+7
-3
python/tvm/__init__.py
+1
-0
python/tvm/_ctypes/_function.py
+2
-2
python/tvm/_ctypes/_node.py
+8
-5
python/tvm/addon/__init__.py
+4
-1
python/tvm/addon/verilog.py
+1
-1
python/tvm/expr.py
+3
-4
python/tvm/intrin.py
+0
-1
python/tvm/ir_builder.py
+333
-0
python/tvm/stmt.py
+4
-0
python/tvm/tensor.py
+6
-2
tests/python/unittest/test_codegen_vm_basic.py
+18
-19
tests/python/unittest/test_ir_builder.py
+42
-0
tests/python/unittest/test_pass_loop_partition.py
+17
-14
No files found.
HalideIR
@
398edacd
Subproject commit
d024efd80694556c1239c4435c5b3e70853a4896
Subproject commit
398edacd956c6de82185821ffd9f482598182e51
docs/api/python/dev.rst
View file @
1c04f389
...
...
@@ -35,9 +35,6 @@ tvm.stmt
tvm.ir_pass
~~~~~~~~~~~
.. automodule:: tvm.ir_pass
:members:
.. autosummary::
tvm.ir_pass.Inline
...
...
@@ -58,6 +55,13 @@ tvm.ir_pass
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass
:members:
tvm.ir_builder
~~~~~~~~~~~~~~
.. automodule:: tvm.ir_builder
:members:
tvm.make
~~~~~~~~
...
...
python/tvm/__init__.py
View file @
1c04f389
...
...
@@ -13,6 +13,7 @@ from . import collections
from
.
import
schedule
from
.
import
module
from
.
import
node
from
.
import
ir_builder
from
.
import
ndarray
as
nd
from
.ndarray
import
cpu
,
gpu
,
opencl
,
cl
,
vpi
...
...
python/tvm/_ctypes/_function.py
View file @
1c04f389
...
...
@@ -13,7 +13,7 @@ from .._base import c_str, py_str, string_types
from
._types
import
TVMValue
,
TypeCode
,
TVMType
,
TVMByteArray
from
._types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
._types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
._node
import
NodeBase
,
SliceBase
,
convert_to_node
from
._node
import
NodeBase
,
NodeGeneric
,
convert_to_node
from
._ndarray
import
NDArrayBase
FunctionHandle
=
ctypes
.
c_void_p
...
...
@@ -114,7 +114,7 @@ def _make_tvm_args(args, temp_args):
elif
isinstance
(
arg
,
string_types
):
values
[
i
]
.
v_str
=
c_str
(
arg
)
type_codes
[
i
]
=
TypeCode
.
STR
elif
isinstance
(
arg
,
(
list
,
tuple
,
dict
,
SliceBase
)):
elif
isinstance
(
arg
,
(
list
,
tuple
,
dict
,
NodeGeneric
)):
arg
=
convert_to_node
(
arg
)
values
[
i
]
.
v_handle
=
arg
.
handle
type_codes
[
i
]
=
TypeCode
.
NODE_HANDLE
...
...
python/tvm/_ctypes/_node.py
View file @
1c04f389
...
...
@@ -41,9 +41,12 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node
,
TypeCode
.
NODE_HANDLE
)
class
SliceBase
(
object
):
"""base class of slice object"""
pass
class
NodeGeneric
(
object
):
"""Base class for all classes that can be converted to node."""
def
asnode
(
self
):
"""Convert value to node"""
raise
NotImplementedError
()
class
NodeBase
(
object
):
"""NodeBase is the base class of all TVM language AST object."""
...
...
@@ -176,8 +179,8 @@ def convert_to_node(value):
vlist
.
append
(
it
[
0
])
vlist
.
append
(
convert_to_node
(
it
[
1
]))
return
_api_internal
.
_Map
(
*
vlist
)
elif
isinstance
(
value
,
SliceBase
):
return
value
.
tensor
(
*
value
.
indices
)
elif
isinstance
(
value
,
NodeGeneric
):
return
value
.
asnode
(
)
else
:
raise
ValueError
(
"don't know how to convert type
%
s to node"
%
type
(
value
))
...
...
python/tvm/addon/__init__.py
View file @
1c04f389
"""Addon utilities to python"""
"""Addon utilities to TVM python package.
These features are useful to have not not essential to TVM.
"""
python/tvm/addon/verilog.py
View file @
1c04f389
"""
Information about nnvm
."""
"""
Verilog simulator modules
."""
from
__future__
import
absolute_import
import
subprocess
...
...
python/tvm/expr.py
View file @
1c04f389
...
...
@@ -50,6 +50,9 @@ class ExprOp(object):
def
__rtruediv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
def
__mod__
(
self
,
other
):
return
_make
.
Mod
(
self
,
other
)
def
__neg__
(
self
):
return
self
.
__mul__
(
-
1
)
...
...
@@ -118,10 +121,6 @@ class Cast(Expr):
pass
@register_node
class
Variable
(
Expr
):
pass
@register_node
class
Add
(
BinaryOpExpr
):
pass
...
...
python/tvm/intrin.py
View file @
1c04f389
...
...
@@ -57,7 +57,6 @@ def call_pure_extern(dtype, func_name, *args):
return
_make
.
Call
(
dtype
,
func_name
,
convert
(
args
),
_Call
.
PureExtern
,
None
,
0
)
def
exp
(
x
):
"""Take exponetial of input x.
...
...
python/tvm/ir_builder.py
0 → 100644
View file @
1c04f389
"""Developer API of IR node builder make function."""
from
__future__
import
absolute_import
as
_abs
from
.
import
api
as
_api
from
.
import
stmt
as
_stmt
from
.
import
expr
as
_expr
from
.
import
make
as
_make
from
.
import
ir_pass
as
_pass
from
.
import
collections
as
_collections
from
._base
import
string_types
from
._ctypes._node
import
NodeGeneric
class
WithScope
(
object
):
"""Auxiliary scope with"""
def
__init__
(
self
,
enter_value
,
exit_cb
):
self
.
_enter_value
=
enter_value
self
.
_exit_cb
=
exit_cb
def
__enter__
(
self
):
return
self
.
_enter_value
def
__exit__
(
self
,
ptype
,
value
,
trace
):
self
.
_exit_cb
()
class
BufferVar
(
NodeGeneric
):
"""Buffer variable with content type, makes load store easily.
Do not create it directly, create use IRBuilder.
Examples
--------
In the follow example, x is BufferVar.
:code:`x[0] = ...` directly emit a store to the IRBuilder,
:code:`x[10]` translates to Load.
.. code-block:: python
# The following code generate IR for x[0] = x[
ib = tvm.ir_builder.create()
x = ib.pointer("float32")
x[0] = x[10] + 1
See Also
--------
IRBuilder.pointer
IRBuilder.buffer_ptr
IRBuilder.allocate
"""
def
__init__
(
self
,
builder
,
buffer_var
,
content_type
):
self
.
_builder
=
builder
self
.
_buffer_var
=
buffer_var
self
.
_content_type
=
content_type
def
asnode
(
self
):
return
self
.
_buffer_var
def
__getitem__
(
self
,
index
):
return
_make
.
Load
(
self
.
_content_type
,
self
.
_buffer_var
,
index
)
def
__setitem__
(
self
,
index
,
value
):
value
=
_api
.
convert
(
value
)
if
value
.
dtype
!=
self
.
_content_type
:
raise
ValueError
(
"data type does not match content type
%
s vs
%
s"
%
(
value
.
dtype
,
self
.
_content_type
))
self
.
_builder
.
emit
(
_make
.
Store
(
self
.
_buffer_var
,
value
,
index
))
class
IRBuilder
(
object
):
"""Auxiliary builder to build IR for testing and dev.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.allocate("float32", n, name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i
% 2
) == 0):
A[i] = A[i] + 1
# The result stmt.
stmt = ib.get()
"""
def
__init__
(
self
):
self
.
_seq_stack
=
[[]]
def
_pop_seq
(
self
):
"""Pop sequence from stack"""
seq
=
self
.
_seq_stack
.
pop
()
if
len
(
seq
)
==
0
or
callable
(
seq
[
-
1
]):
seq
.
append
(
_make
.
Evaluate
(
0
))
stmt
=
seq
[
-
1
]
for
s
in
reversed
(
seq
[:
-
1
]):
if
callable
(
s
):
stmt
=
s
(
stmt
)
else
:
assert
isinstance
(
s
,
_stmt
.
Stmt
)
stmt
=
_make
.
Block
(
s
,
stmt
)
return
stmt
def
emit
(
self
,
stmt
):
"""Emit a statement to the end of current scope.
Parameters
----------
stmt : Stmt or callable.
The statement to be emitted or callable that build stmt given body.
"""
if
isinstance
(
stmt
,
_expr
.
Call
):
stmt
=
_make
.
Evaluate
(
stmt
)
assert
isinstance
(
stmt
,
_stmt
.
Stmt
)
or
callable
(
stmt
)
self
.
_seq_stack
[
-
1
]
.
append
(
stmt
)
def
scope_attr
(
self
,
node
,
attr_key
,
value
):
"""Create an AttrStmt at current scope.
Parameters
----------
attr_key : str
The key of the attribute type.
node : Node
The attribute node to annottate on.
value : Expr
Attribute value.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
ib.scope_attr(x, "storage_scope", "global")
x[i] = x[i - 1] + 1
"""
if
isinstance
(
node
,
string_types
):
node
=
_make
.
StringImm
(
node
)
if
isinstance
(
value
,
string_types
):
value
=
_make
.
StringImm
(
value
)
self
.
emit
(
lambda
x
:
_make
.
AttrStmt
(
node
,
attr_key
,
value
,
x
))
def
for_range
(
self
,
begin
,
end
,
name
=
"i"
,
dtype
=
"int32"
):
"""Create a for iteration scope.
Parameters
----------
begin : Expr
The min iteration scope.
end : Expr
The end iteration scope
name : str, optional
The name of iteration variable
dtype : str, optional
The data type of iteration variable.
Returns
-------
loop_scope : With.Scope of Var
The for scope, when enters returns loop_var
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
x = ib.pointer("float32")
with ib.for_range(1, 10, name="i") as i:
x[i] = x[i - 1] + 1
"""
self
.
_seq_stack
.
append
([])
loop_var
=
_api
.
var
(
name
,
dtype
=
dtype
)
extent
=
end
if
begin
==
0
else
_pass
.
Simplify
(
end
-
begin
)
def
_exit_cb
():
self
.
emit
(
_make
.
For
(
loop_var
,
begin
,
extent
,
0
,
0
,
self
.
_pop_seq
()))
return
WithScope
(
loop_var
,
_exit_cb
)
def
if_scope
(
self
,
cond
):
"""Create an if scope.
Parameters
----------
cond : Expr
The condition.
Returns
-------
if_scope : WithScope
The result if scope.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
with ib.if_scope((i
% 2
) == 0):
x[i] = x[i - 1] + 1
"""
self
.
_seq_stack
.
append
([])
def
_exit_cb
():
self
.
emit
(
_make
.
IfThenElse
(
cond
,
self
.
_pop_seq
(),
None
))
return
WithScope
(
None
,
_exit_cb
)
def
else_scope
(
self
):
"""Create an else scope.
This can only be used right after an if scope.
Returns
-------
else_scope : WithScope
The result else scope.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
with ib.if_scope((i
% 2
) == 0):
x[i] = x[i - 1] + 1
with ib.else_scope():
x[i] = x[i - 1] + 2
"""
if
len
(
self
.
_seq_stack
[
-
1
])
==
0
:
raise
RuntimeError
(
"else_scope can only follow an if_scope"
)
prev
=
self
.
_seq_stack
[
-
1
][
-
1
]
if
not
isinstance
(
prev
,
_stmt
.
IfThenElse
)
or
prev
.
else_case
:
raise
RuntimeError
(
"else_scope can only follow an if_scope"
)
self
.
_seq_stack
[
-
1
]
.
pop
()
self
.
_seq_stack
.
append
([])
def
_exit_cb
():
self
.
emit
(
_make
.
IfThenElse
(
prev
.
condition
,
prev
.
then_case
,
self
.
_pop_seq
()))
return
WithScope
(
None
,
_exit_cb
)
def
allocate
(
self
,
dtype
,
shape
,
name
=
"buf"
,
scope
=
None
):
"""Create a allocate statement.
Parameters
----------
dtype : str
The content data type.
shape : tuple of Expr
The shape of array to be allocated.
name : str, optional
The name of the buffer.
scope : str, optional
The scope of the buffer.
Returns
-------
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var
=
_api
.
var
(
name
,
dtype
=
"handle"
)
if
not
isinstance
(
shape
,
(
list
,
tuple
,
_collections
.
Array
)):
shape
=
[
shape
]
if
scope
:
self
.
scope_attr
(
buffer_var
,
"storage_scope"
,
scope
)
self
.
emit
(
lambda
x
:
_make
.
Allocate
(
buffer_var
,
dtype
,
shape
,
_api
.
const
(
1
,
dtype
=
"uint1"
),
x
))
return
BufferVar
(
self
,
buffer_var
,
dtype
)
def
pointer
(
self
,
content_type
,
name
=
"ptr"
):
"""Create pointer variable with content type.
Parameters
----------
content_type : str
The content data type.
name : str, optional
The name of the pointer.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
buffer_var
=
_api
.
var
(
name
,
dtype
=
"handle"
)
return
BufferVar
(
self
,
buffer_var
,
content_type
)
def
buffer_ptr
(
self
,
buf
):
"""Create pointer variable corresponds to buffer ptr.
Parameters
----------
buf : Buffer
The buffer to be extracted.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
return
BufferVar
(
self
,
buf
.
data
,
buf
.
dtype
)
def
get
(
self
):
"""Return the builded IR.
Returns
-------
stmt : Stmt
The result statement.
"""
seq
=
self
.
_pop_seq
()
if
len
(
self
.
_seq_stack
)
!=
0
:
raise
RuntimeError
(
"cannot call get inside construction scope"
)
return
seq
def
create
():
"""Create a new IRBuilder
Returns
-------
builder : IRBuilder
The created IRBuilder
"""
return
IRBuilder
()
python/tvm/stmt.py
View file @
1c04f389
...
...
@@ -51,6 +51,10 @@ class Allocate(Stmt):
pass
@register_node
class
AttrStmt
(
Stmt
):
pass
@register_node
class
Free
(
Stmt
):
pass
...
...
python/tvm/tensor.py
View file @
1c04f389
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
as
_abs
from
._ctypes._node
import
NodeBase
,
SliceBase
,
register_node
,
convert_to_node
from
._ctypes._node
import
NodeBase
,
NodeGeneric
,
register_node
,
convert_to_node
from
.
import
_api_internal
from
.
import
make
as
_make
from
.
import
expr
as
_expr
class
TensorSlice
(
SliceBase
,
_expr
.
ExprOp
):
class
TensorSlice
(
NodeGeneric
,
_expr
.
ExprOp
):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def
__init__
(
self
,
tensor
,
indices
):
if
not
isinstance
(
indices
,
tuple
):
...
...
@@ -19,6 +19,10 @@ class TensorSlice(SliceBase, _expr.ExprOp):
indices
=
(
indices
,)
return
TensorSlice
(
self
.
tensor
,
self
.
indices
+
indices
)
def
asnode
(
self
):
"""Convert slice to node."""
return
self
.
tensor
(
*
self
.
indices
)
@property
def
dtype
(
self
):
"""Data content of the tensor."""
...
...
tests/python/unittest/test_codegen_vm_basic.py
View file @
1c04f389
...
...
@@ -28,20 +28,19 @@ def test_stack_vm_basic():
def
tvm_stack_vm_print
(
*
x
):
print
(
x
)
def
test_stack_vm_loop
():
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
Ab
=
tvm
.
decl_buffer
((
n
,
),
dtype
)
i
=
tvm
.
var
(
'i'
)
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
0
,
n
-
1
,
0
,
0
,
tvm
.
make
.
Block
(
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
i
+
1
),
tvm
.
make
.
Evaluate
(
tvm
.
call_packed
(
"tvm_stack_vm_print"
,
i
)))
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
buffer_ptr
(
Ab
)
with
ib
.
for_range
(
0
,
n
-
1
,
"i"
)
as
i
:
A
[
i
+
1
]
=
A
[
i
]
+
1
ib
.
emit
(
tvm
.
call_packed
(
"tvm_stack_vm_print"
,
i
))
stmt
=
ib
.
get
(
)
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"ramp"
,
[
Ab
],
0
)
a
=
tvm
.
nd
.
array
(
np
.
zeros
(
10
,
dtype
=
dtype
))
def
check
(
f
):
...
...
@@ -54,16 +53,16 @@ def test_stack_vm_cond():
dtype
=
'int64'
n
=
tvm
.
var
(
'n'
)
Ab
=
tvm
.
decl_buffer
((
n
,
),
dtype
)
i
=
tvm
.
var
(
'i'
)
# for i in 0 to n-1:
stmt
=
tvm
.
make
.
For
(
i
,
0
,
n
-
1
,
0
,
0
,
tvm
.
make
.
IfThenElse
(
tvm
.
make
.
EQ
(
i
,
4
),
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
i
+
1
),
tvm
.
make
.
Store
(
Ab
.
data
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
2
,
i
+
1
))
)
ib
=
tvm
.
ir_builder
.
create
()
A
=
ib
.
buffer_ptr
(
Ab
)
with
ib
.
for_range
(
0
,
n
-
1
,
"i"
)
as
i
:
with
ib
.
if_scope
(
tvm
.
make
.
EQ
(
i
,
4
)):
A
[
i
+
1
]
=
A
[
i
]
+
1
with
ib
.
else_scope
():
A
[
i
+
1
]
=
A
[
i
]
+
2
stmt
=
ib
.
get
(
)
fapi
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"test"
,
[
Ab
],
0
)
def
check
(
f
):
a
=
tvm
.
nd
.
array
(
np
.
zeros
(
10
,
dtype
=
dtype
))
...
...
tests/python/unittest/test_ir_builder.py
0 → 100644
View file @
1c04f389
import
tvm
def
test_for
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
A
=
ib
.
allocate
(
"float32"
,
n
,
name
=
"A"
,
scope
=
"global"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
A
[
j
]
=
A
[
j
]
+
2
body
=
ib
.
get
()
print
(
body
)
assert
isinstance
(
body
,
tvm
.
stmt
.
AttrStmt
)
body
=
body
.
body
assert
isinstance
(
body
,
tvm
.
stmt
.
Allocate
)
body
=
body
.
body
assert
isinstance
(
body
,
tvm
.
stmt
.
For
)
body
=
body
.
body
assert
isinstance
(
body
,
tvm
.
stmt
.
Block
)
assert
isinstance
(
body
.
rest
,
tvm
.
stmt
.
For
)
def
test_if
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
A
=
ib
.
pointer
(
"float32"
,
name
=
"A"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
if_scope
((
i
%
2
)
==
0
):
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
else_scope
():
A
[
0
]
=
A
[
i
]
+
2
body
=
ib
.
get
()
assert
isinstance
(
body
,
tvm
.
stmt
.
For
)
body
=
body
.
body
assert
isinstance
(
body
,
tvm
.
stmt
.
IfThenElse
)
assert
isinstance
(
body
.
then_case
.
index
,
tvm
.
expr
.
Var
)
assert
body
.
else_case
.
index
.
value
==
0
if
__name__
==
"__main__"
:
test_if
()
test_for
()
tests/python/unittest/test_pass_loop_partition.py
View file @
1c04f389
import
tvm
def
collect_visit
(
stmt
,
f
):
ret
=
[]
tvm
.
ir_pass
.
PostOrderVisit
(
stmt
,
lambda
x
:
ret
.
append
(
f
(
x
)))
return
ret
def
test_basic
():
n
=
tvm
.
var
(
'n'
)
A
=
tvm
.
placeholder
((
n
,
),
name
=
'A'
)
...
...
@@ -16,22 +21,20 @@ def test_basic():
print
(
stmt
)
def
test_multi_loop
():
i
=
tvm
.
var
(
'i'
)
j
=
tvm
.
var
(
'j'
)
k
=
tvm
.
var
(
'k'
)
ib
=
tvm
.
ir_builder
.
create
()
m
=
tvm
.
var
(
'm'
)
n
=
tvm
.
var
(
'n'
)
stmt
=
tvm
.
make
.
For
(
i
,
0
,
4
,
0
,
0
,
tvm
.
make
.
For
(
j
,
0
,
n
,
0
,
0
,
tvm
.
make
.
For
(
k
,
0
,
m
,
0
,
0
,
tvm
.
make
.
IfThenElse
(
(
i
*
m
+
j
+
k
<
n
),
tvm
.
make
.
Evaluate
(
m
),
tvm
.
make
.
Evaluate
(
n
))))
)
with
ib
.
for_range
(
0
,
4
,
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
m
,
"k"
)
as
k
:
with
ib
.
if_scope
(
i
*
m
+
j
+
k
<
n
):
ib
.
emit
(
tvm
.
make
.
Evaluate
(
m
))
with
ib
.
else_scope
():
ib
.
emit
(
tvm
.
make
.
Evaluate
(
n
))
stmt
=
ib
.
get
(
)
stmt
=
tvm
.
ir_pass
.
LoopPartition
(
stmt
)
assert
(
'if'
not
in
str
(
stmt
.
body
.
first
))
print
(
stmt
)
assert
(
not
any
(
collect_visit
(
stmt
.
body
.
first
,
lambda
x
:
isinstance
(
x
,
tvm
.
stmt
.
IfThenElse
)))
)
def
test_multi_if
():
i
=
tvm
.
var
(
'i'
)
...
...
@@ -74,7 +77,7 @@ def test_thread_axis():
print
(
stmt_
)
if
__name__
==
"__main__"
:
test_basic
()
test_multi_loop
()
test_basic
()
test_multi_if
()
test_thread_axis
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment