Commit f1438813 (unverified), authored Apr 15, 2020 by Tianqi Chen, committed by GitHub on Apr 15, 2020.

[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (#5335)

Parent: d81b006b
Showing 18 changed files with 117 additions and 141 deletions.
python/tvm/ir/function.py                                       +31   -0
python/tvm/relay/function.py                                     +0  -19
python/tvm/testing.py                                            +0  -37
python/tvm/tir/function.py                                       +0  -19
src/ir/function.cc                                              +26   -0
src/ir/module.cc                                                 +7   -1
src/relay/ir/function.cc                                         +0   -6
src/tir/ir/function.cc                                           +0   -6
tests/python/unittest/test_runtime_extension.py                  +2   -1
tests/python/unittest/test_runtime_module_load.py                +5   -2
tests/python/unittest/test_target_codegen_llvm.py                +8   -4
tests/python/unittest/test_target_codegen_static_init.py         +8  -15
tests/python/unittest/test_target_codegen_vm_basic.py           +14  -20
tests/python/unittest/test_tir_nodes.py                          +1   -1
tests/python/unittest/test_tir_pass_storage_flatten.py           +3   -1
tests/python/unittest/test_tir_transform_lower_warp_memory.py    +1   -1
tests/python/unittest/test_tir_transform_make_packed_api.py      +5   -5
tests/python/unittest/test_tir_transform_thread_sync.py          +6   -3
python/tvm/ir/function.py
@@ -16,6 +16,8 @@
 # under the License.
 """Function defintiions."""
+from enum import IntEnum
+import tvm.runtime
 from .expr import RelayExpr
 from . import _ffi_api
@@ -34,3 +36,32 @@ class BaseFunc(RelayExpr):
         """Return the attrs member of the function.
         """
         return _ffi_api.BaseFunc_Attrs(self)
+
+    def with_attr(self, attr_key_or_dict, attr_value=None):
+        """Create a new copy of the function and update the attribute.
+
+        Parameters
+        ----------
+        attr_key_or_dict : Union[str, dict]
+            The attribute key to use or a dict containing multiple key value pairs.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        # make sure we first copy so that we can safely do copy on write
+        # for multiple updates.
+        res = _ffi_api.BaseFuncCopy(self)
+
+        if isinstance(attr_key_or_dict, dict):
+            for key, val in attr_key_or_dict.items():
+                res = _ffi_api.BaseFuncWithAttr(
+                    res._move(), key, tvm.runtime.convert(val))
+            return res
+
+        return _ffi_api.BaseFuncWithAttr(
+            res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
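A minimal usage sketch of the enhanced API (a hypothetical snippet assuming the TVM Python API at this commit, not part of the diff): with_attr keeps its old single key/value form and gains a dict form for batching updates.

    import tvm
    from tvm import tir

    x = tir.Var("x", "int32")
    func = tir.PrimFunc([x], tir.Evaluate(x))

    # Old form: one key, one value.
    f1 = func.with_attr("global_symbol", "main")

    # New form: a dict updates several attributes in one call.
    f2 = func.with_attr({"global_symbol": "main", "tir.noalias": True})
    assert str(f2.attrs["global_symbol"]) == "main"

    # with_attr is copy-on-write: the original function is untouched.
    assert func.attrs is None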
python/tvm/relay/function.py
@@ -65,22 +65,3 @@ class Function(BaseFunc):
             Arguments.
         """
         return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
python/tvm/testing.py
@@ -168,41 +168,4 @@ def check_numerical_grads(function, input_values, grad_values, function_value=None,
                 x_name, grad.shape, dist, max_diff, avg_diff)
-
-
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to build a Module from statement.
-
-    Used for migrating existing test cases only.
-
-    Parameters
-    ----------
-    stmt: Stmt
-        The input statement.
-
-    name: str
-        The name of the funciton.
-
-    args: list of Buffer or Vars
-        The function arguments
-
-    num_unpacked_args: int
-        Number of unpacked arguments.
-
-    nolias: bool
-        Whether allow noalias.
-
-    Returns
-    -------
-    mod : IRModule
-        The created IRModule.
-    """
-    assert num_unpacked_args == 0
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: f})
-    return mod

 tvm._ffi._init_api("testing", __name__)
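Every test file below drops this adapter in favor of the same construction; distilled into a hypothetical helper (not part of the commit), the replacement pattern reads:

    def make_mod(stmt, name, args):
        """What tvm.testing.MakeAPILegacy(stmt, name, args, 0, True) becomes:
        wrap the statement in a PrimFunc, tag its global_symbol, and let
        IRModule.from_expr pick up the name."""
        return tvm.IRModule.from_expr(
            tvm.tir.PrimFunc(args, stmt).with_attr("global_symbol", name))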
python/tvm/tir/function.py
@@ -67,22 +67,3 @@ class PrimFunc(BaseFunc):
         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.PrimFuncWithAttr(
-            self, attr_key, tvm.runtime.convert(attr_value))
src/ir/function.cc
@@ -23,6 +23,14 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/function.h>
+// NOTE: reverse dependency on relay, tir/
+// These dependencies do not happen at the interface-level,
+// and are only used in minimum cases where they are clearly marked.
+//
+// Rationale: We calls into the type specific WithAttr function
+#include <tvm/tir/function.h>
+#include <tvm/relay/function.h>

 namespace tvm {
@@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
   return func->attrs;
 });

+TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
+.set_body_typed([](BaseFunc func) {
+  return func;
+});
+
+TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
+.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+  if (func->IsInstance<tir::PrimFuncNode>()) {
+    return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
+  } else if (func->IsInstance<relay::FunctionNode>()) {
+    return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
+  } else {
+    LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
+    return func;
+  }
+});
+
 }  // namespace tvm
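In effect, a single packed function now backs with_attr for both IR dialects; a rough sketch of the Python-side behavior this enables (assumed usage, not part of the diff):

    import tvm
    from tvm import relay, tir

    # Both function types route through BaseFunc.with_attr, which calls
    # ir.BaseFuncWithAttr; the C++ side downcasts to the concrete type.
    rfunc = relay.Function([], relay.const(1))
    pfunc = tir.PrimFunc([], tir.Evaluate(tvm.tir.const(0, "int32")))

    rfunc2 = rfunc.with_attr("global_symbol", "rfn")
    pfunc2 = pfunc.with_attr("global_symbol", "pfn")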
src/ir/module.cc
@@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
     const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
   auto mod = IRModule(global_funcs, type_definitions);
   BaseFunc func;
+  std::string gv_name = "main";
+
   if (auto* func_node = expr.as<BaseFuncNode>()) {
     func = GetRef<BaseFunc>(func_node);
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      gv_name = opt.value();
+    }
   } else {
     func = relay::Function(
         relay::FreeVars(expr), expr, Type(),
         relay::FreeTypeVars(expr, mod), {});
   }
-  auto main_gv = GlobalVar("main");
+  auto main_gv = GlobalVar(gv_name);
   mod->Add(main_gv, func);
   return mod;
 }
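The practical consequence, sketched under the same assumptions as above (the name "my_func" is hypothetical): IRModule.from_expr now keeps a function's own global_symbol instead of always binding it to "main".

    import tvm
    from tvm import tir

    x = tir.Var("x", "int32")
    f = tir.PrimFunc([x], tir.Evaluate(x)).with_attr("global_symbol", "my_func")
    mod = tvm.IRModule.from_expr(f)

    # The global var is named after the attribute, not "main".
    assert [gv.name_hint for gv in mod.get_global_vars()] == ["my_func"]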
src/relay/ir/function.cc
@@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       << node->attrs << ")";
 });

-TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
-.set_body_typed(
-  [](Function func, std::string name, ObjectRef ref) {
-    return WithAttr(std::move(func), name, ref);
-});
-
 }  // namespace relay
 }  // namespace tvm
src/tir/ir/function.cc
@@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
   return PrimFunc(params, body, ret_type, buffer_map, attrs);
 });

-TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
-.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
-  return WithAttr(std::move(func), name, ref);
-});
-
 }  // namespace tir
 }  // namespace tvm
tests/python/unittest/test_runtime_extension.py
@@ -39,7 +39,8 @@ def test_dltensor_compatible():
         A[i + 1] = A[i] + 1
     stmt = ib.get()
-    mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
     f = tvm.build(mod, target="stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     aview = MyTensorView(a)
tests/python/unittest/test_runtime_module_load.py
@@ -57,8 +57,11 @@ def test_dso_module_load():
         tvm.tir.Store(Ab.data,
                       tvm.tir.Load(dtype, Ab.data, i) + 1,
                       i + 1))
-    m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    m = tvm.driver.build(m, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
+    )
+    m = tvm.driver.build(mod, target="llvm")
+
     for name in names:
         m.save(name)
tests/python/unittest/test_target_codegen_llvm.py
@@ -36,8 +36,11 @@ def test_llvm_intrin():
                               "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
     body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch")
+    )
+    fcode = tvm.build(mod, None, "llvm")
+

 def test_llvm_overloaded_intrin():
@@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
     x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
     ib.emit(x)
     body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
+    fcode = tvm.build(mod, None, "llvm")

 def test_llvm_large_uintimm():
tests/python/unittest/test_target_codegen_static_init.py
@@ -20,17 +20,6 @@ import ctypes
 import numpy as np

-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create a API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
-
 def test_static_callback():
     dtype = 'int64'
     n = te.size_var('n')
@@ -44,8 +33,11 @@ def test_static_callback():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
+    )
+    f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
     f(a)
@@ -67,8 +59,9 @@ def test_static_init():
         return sh
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
+    f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
tests/python/unittest/test_target_codegen_vm_basic.py
@@ -26,18 +26,6 @@ def run_jit(fapi, check):
     s = f.get_source()
     check(f)

-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create a API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
-
 def test_stack_vm_basic():
     a = tvm.nd.array(np.zeros(10, dtype='float32'))

     @tvm.register_func
@@ -48,8 +36,11 @@ def test_stack_vm_basic():
     n = te.size_var('n')
     Ab = tvm.tir.decl_buffer((n, ), "float32")
     stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
-    fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
-    run_jit(fapi, lambda f: f(a))
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))
+    run_jit(mod, lambda f: f(a))
+

 @tvm.register_func
@@ -69,12 +60,13 @@ def test_stack_vm_loop():
         ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))

     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     def check(f):
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)

 def test_stack_vm_cond():
@@ -91,14 +83,15 @@ def test_stack_vm_cond():
             A[i + 1] = A[i] + 2

     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         y = np.arange(a.shape[0]) * 2
         y[5:] -= 1
         np.testing.assert_equal(a.asnumpy(), y)
-    run_jit(fapi, check)
+    run_jit(mod, check)

 def test_vm_parallel():
     dtype = 'int64'
@@ -110,12 +103,13 @@ def test_vm_parallel():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)

 if __name__ == "__main__":
tests/python/unittest/test_tir_nodes.py
@@ -277,7 +277,7 @@ def test_prim_func():
     assert func.buffer_map[func.params[2]].same_as(b)

     assert len(func.buffer_map) == 1
-    f2 = func.with_attr("calling_conv", 1)
+    f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
     assert f2.attrs["calling_conv"].value == 1
     assert func.attrs is None
tests/python/unittest/test_tir_pass_storage_flatten.py
@@ -92,7 +92,9 @@ def test_flatten_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
+
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]

     count = [0]
tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -43,7 +43,7 @@ def test_lower_warp_memory_local_scope():
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
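The lookup key changes from "main" to "f_kernel0" as a direct consequence of the src/ir/module.cc change above: IRModule.from_expr now binds fdevice under the global_symbol that SplitHostDevice assigned to it, rather than renaming it to "main".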
tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -35,11 +35,11 @@ def test_makeapi():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb}, 64)
     num_unpacked_args = 2
-    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt)
-    f = f.with_attr("global_symbol", "myadd")
-    f = f.with_attr("target", tvm.target.create("llvm"))
-    mod = tvm.IRModule.from_expr(f)
-
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
+            "global_symbol": "main",
+            "target": tvm.target.create("llvm")
+        }))
     f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
     assert(len(f.params) == 7)
tests/python/unittest/test_tir_transform_thread_sync.py
@@ -39,12 +39,15 @@ def test_thread_storage_sync():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)

     cuda_target = tvm.target.create("cuda")
-    mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
+            "global_symbol": "test", "target": cuda_target}))
+
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
-    cuda_target = tvm.target.create("cuda")
-    f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
+    f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
     body_list = tvm.tir.stmt_list(f.body.body.body.body)
     assert(body_list[1].value.name == "tvm_storage_sync")