Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7f82912b
Commit
7f82912b
authored
Jan 15, 2017
by
Tianqi Chen
Committed by
GitHub
Jan 15, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Basic storage flatten (#13)
parent
0992873a
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
198 additions
and
4 deletions
+198
-4
python/tvm/_ctypes/_api.py
+1
-1
python/tvm/function.py
+2
-1
src/c_api/c_api_pass.cc
+1
-0
src/lang/buffer.cc
+1
-1
src/pass/ir_mutator.cc
+1
-1
src/pass/storage_flatten.cc
+168
-0
tests/python/test_pass_storage_flatten.py
+24
-0
No files found.
python/tvm/_ctypes/_api.py
View file @
7f82912b
...
...
@@ -225,7 +225,7 @@ def _make_function(handle, name):
"""TVM function"""
cargs
=
[]
for
x
in
args
:
if
isinstance
(
x
,
(
list
,
tuple
,
SliceBase
)):
if
isinstance
(
x
,
(
list
,
tuple
,
dict
,
SliceBase
)):
cargs
.
append
(
convert
(
x
))
else
:
cargs
.
append
(
x
)
...
...
python/tvm/function.py
View file @
7f82912b
...
...
@@ -133,7 +133,8 @@ def compute(shape, fcompute, name="compute"):
def
Buffer
(
shape
,
dtype
=
None
,
name
=
"buffer"
,
ptr
=
None
,
name
=
"buffer"
,
ptr
=
None
,
strides
=
None
):
"""Create a new buffer
...
...
src/c_api/c_api_pass.cc
View file @
7f82912b
...
...
@@ -36,6 +36,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1
(
VerifySSA
);
REGISTER_PASS4
(
Inline
);
REGISTER_PASS2
(
ScheduleOps
);
REGISTER_PASS2
(
StorageFlatten
);
}
// namespace ir
}
// namespace tvm
src/lang/buffer.cc
View file @
7f82912b
...
...
@@ -51,7 +51,7 @@ Expr Buffer::MakeLoad(Array<Expr> index) const {
Stmt
Buffer
::
MakeStore
(
Array
<
Expr
>
index
,
Expr
value
)
const
{
const
BufferNode
*
n
=
operator
->
();
CHECK_EQ
(
value
.
type
(),
n
->
dtype
);
return
ir
::
Store
::
make
(
n
->
ptr
,
BufferOffset
(
n
,
index
),
value
);
return
ir
::
Store
::
make
(
n
->
ptr
,
value
,
BufferOffset
(
n
,
index
)
);
}
Buffer
BufferNode
::
make
(
std
::
string
name
,
...
...
src/pass/ir_mutator.cc
View file @
7f82912b
...
...
@@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
body
.
same_as
(
op
->
body
))
{
return
s
;
}
else
{
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
op
->
value
,
op
->
body
);
return
AttrStmt
::
make
(
op
->
node
,
op
->
type_key
,
value
,
body
);
}
});
...
...
src/pass/storage_flatten.cc
0 → 100644
View file @
7f82912b
/*!
* Copyright (c) 2016 by Contributors
* \file storage_flatten.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
namespace
tvm
{
namespace
ir
{
// key of function buffer
struct
TensorKey
{
FunctionRef
f
;
int
value_index
;
inline
bool
operator
==
(
const
TensorKey
&
other
)
const
{
return
f
==
other
.
f
&&
value_index
==
other
.
value_index
;
}
inline
std
::
string
GetName
()
const
{
if
(
f
->
num_outputs
()
==
1
)
return
f
->
func_name
();
std
::
ostringstream
os
;
os
<<
f
->
func_name
()
<<
".v"
<<
value_index
;
return
os
.
str
();
}
};
}
// namespace ir
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
ir
::
TensorKey
>
{
std
::
size_t
operator
()(
const
::
tvm
::
ir
::
TensorKey
&
k
)
const
{
size_t
lhs
=
k
.
f
.
hash
();
size_t
rhs
=
static_cast
<
size_t
>
(
k
.
value_index
);
lhs
^=
rhs
+
0x9e3779b9
+
(
lhs
<<
6
)
+
(
lhs
>>
2
);
return
lhs
;
}
};
}
// namespace std
namespace
tvm
{
namespace
ir
{
using
Halide
::
Internal
::
Region
;
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class
StorageFlattener
:
public
IRMutator
{
public
:
explicit
StorageFlattener
(
Map
<
Tensor
,
Buffer
>
extern_buffer
)
{
for
(
auto
kv
:
extern_buffer
)
{
BufferEntry
e
;
e
.
buffer
=
kv
.
second
;
e
.
external
=
true
;
buf_map_
[
TensorKey
{
kv
.
first
->
op
,
kv
.
first
->
value_index
}]
=
e
;
}
}
Expr
Mutate
(
Expr
expr
)
final
{
expr
=
IRMutator
::
Mutate
(
expr
);
const
Call
*
op
=
expr
.
as
<
Call
>
();
if
(
op
!=
nullptr
&&
op
->
call_type
==
Call
::
Halide
)
{
TensorKey
key
{
op
->
func
,
op
->
value_index
};
auto
it
=
buf_map_
.
find
(
key
);
CHECK
(
it
!=
buf_map_
.
end
())
<<
"Cannot find allocated buffer for "
<<
key
.
f
;
const
BufferEntry
&
e
=
it
->
second
;
CHECK
(
!
e
.
released
)
<<
"Read a buffer that is already out of scope"
;
return
e
.
buffer
.
MakeLoad
(
e
.
RelIndex
(
op
->
args
));
}
else
{
return
expr
;
}
}
Stmt
Mutate
(
Stmt
stmt
)
final
{
const
Realize
*
realize
=
stmt
.
as
<
Realize
>
();
if
(
realize
!=
nullptr
)
{
return
HandleRealize
(
realize
);
}
else
if
(
stmt
.
as
<
Provide
>
())
{
return
HandleProvide
(
stmt
);
}
else
{
return
IRMutator
::
Mutate
(
stmt
);
}
}
private
:
// The buffer entry in the flatten map
struct
BufferEntry
{
// the buffer of storage
Buffer
buffer
;
// the bounds of realization, can be null
Region
bounds
;
// Whether the buffer is external
bool
external
{
false
};
// Whether we are out of allocation bounds and buffer get released.
bool
released
{
false
};
// TODO(tqchen) allow permutation and inference of index dimension.
// relative index
inline
Array
<
Expr
>
RelIndex
(
Array
<
Expr
>
args
)
const
{
if
(
bounds
.
size
()
!=
0
)
{
Array
<
Expr
>
index
;
CHECK_EQ
(
bounds
.
size
(),
args
.
size
());
for
(
size_t
i
=
0
;
i
<
bounds
.
size
();
++
i
)
{
index
.
push_back
(
args
[
i
]
-
bounds
[
i
]
->
min
);
}
return
index
;
}
else
{
return
args
;
}
}
};
// The buffer assignment map
std
::
unordered_map
<
TensorKey
,
BufferEntry
>
buf_map_
;
Stmt
HandleRealize
(
const
Realize
*
op
)
{
TensorKey
key
{
op
->
func
,
op
->
value_index
};
if
(
buf_map_
.
count
(
key
))
{
CHECK
(
buf_map_
.
at
(
key
).
external
);
return
this
->
Mutate
(
op
->
body
);
}
else
{
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry
e
;
e
.
bounds
=
op
->
bounds
;
Array
<
Expr
>
shape
;
for
(
auto
r
:
e
.
bounds
)
{
shape
.
push_back
(
r
->
extent
);
}
e
.
buffer
=
Buffer
(
shape
,
op
->
type
,
key
.
GetName
());
buf_map_
[
key
]
=
e
;
Stmt
body
=
this
->
Mutate
(
op
->
body
);
buf_map_
[
key
].
released
=
true
;
return
Allocate
::
make
(
e
.
buffer
->
ptr
,
e
.
buffer
->
dtype
,
e
.
buffer
->
shape
,
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
}
}
Stmt
HandleProvide
(
Stmt
stmt
)
{
stmt
=
IRMutator
::
Mutate
(
stmt
);
const
Provide
*
op
=
stmt
.
as
<
Provide
>
();
TensorKey
key
{
op
->
func
,
op
->
value_index
};
auto
it
=
buf_map_
.
find
(
key
);
CHECK
(
it
!=
buf_map_
.
end
())
<<
"Cannot find allocated buffer for "
<<
key
.
f
;
const
BufferEntry
&
e
=
it
->
second
;
CHECK
(
!
e
.
released
)
<<
"Read a buffer that is already out of scope"
;
return
e
.
buffer
.
MakeStore
(
e
.
RelIndex
(
op
->
args
),
op
->
value
);
}
};
Stmt
StorageFlatten
(
Stmt
stmt
,
Map
<
Tensor
,
Buffer
>
extern_buffer
)
{
stmt
=
StorageFlattener
(
extern_buffer
).
Mutate
(
stmt
);
return
stmt
;
}
}
// namespace ir
}
// namespace tvm
tests/python/test_pass_storage_flatten.py
0 → 100644
View file @
7f82912b
import
tvm
def
test_flatten2
():
m
=
tvm
.
Var
(
'm'
)
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
A1
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A
[
i
,
j
],
name
=
'A1'
)
A2
=
tvm
.
compute
((
m
,
l
),
lambda
i
,
j
:
A1
[
i
,
j
]
+
3
,
name
=
'A2'
)
s
=
tvm
.
Schedule
(
A2
.
op
)
xo
,
xi
=
s
[
A2
]
.
split
(
A2
.
op
.
axis
[
0
],
8
)
s
[
A1
]
.
compute_at
(
s
[
A2
],
xo
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
assert
isinstance
(
bounds
,
tvm
.
collections
.
Map
)
stmt
=
tvm
.
ir_pass
.
ScheduleOps
(
s
,
bounds
)
print
(
stmt
)
Ab
=
tvm
.
Buffer
(
A
.
shape
,
A
.
dtype
,
name
=
'A'
)
A2b
=
tvm
.
Buffer
(
A2
.
shape
,
A2
.
dtype
,
name
=
'A2'
)
stmt
=
tvm
.
ir_pass
.
StorageFlatten
(
stmt
,
{
A
:
Ab
,
A2
:
A2b
})
print
(
stmt
)
if
__name__
==
"__main__"
:
test_flatten2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment