Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
062bb853
Commit
062bb853
authored
Oct 26, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add in Array, fix most of IR
parent
622cee7a
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
160 additions
and
11 deletions
+160
-11
HalideIR
+1
-1
python/tvm/__init__.py
+1
-0
python/tvm/_ctypes/_api.py
+31
-3
python/tvm/domain.py
+16
-0
python/tvm/expr.py
+4
-0
python/tvm/function.py
+17
-2
python/tvm/stmt.py
+4
-0
src/c_api/c_api_function.cc
+37
-0
src/c_api/c_api_ir.cc
+26
-3
src/c_api/c_api_registry.h
+4
-2
tests/python/test_basic.py
+19
-0
No files found.
HalideIR
@
9070ac36
Subproject commit
872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
Subproject commit
9070ac3697931ef5aeb8c373c23b2e8a2fec4627
python/tvm/__init__.py
View file @
062bb853
...
@@ -6,3 +6,4 @@ from ._ctypes._api import register_node
...
@@ -6,3 +6,4 @@ from ._ctypes._api import register_node
from
.
import
expr
from
.
import
expr
from
.
import
stmt
from
.
import
stmt
from
.
import
make
from
.
import
make
from
.
import
domain
python/tvm/_ctypes/_api.py
View file @
062bb853
...
@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs
...
@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs
import
ctypes
import
ctypes
import
sys
import
sys
from
numbers
import
Number
as
Number
from
numbers
import
Number
,
Integral
from
.._base
import
_LIB
from
.._base
import
_LIB
from
.._base
import
c_str
,
py_str
,
string_types
from
.._base
import
c_str
,
py_str
,
string_types
...
@@ -93,6 +93,27 @@ class NodeBase(object):
...
@@ -93,6 +93,27 @@ class NodeBase(object):
names
.
append
(
py_str
(
plist
[
i
]))
names
.
append
(
py_str
(
plist
[
i
]))
return
names
return
names
def
const
(
value
,
dtype
=
None
):
"""construct a constant"""
if
dtype
is
None
:
if
isinstance
(
value
,
Integral
):
dtype
=
'int32'
else
:
dtype
=
'float32'
return
_function_internal
.
_const
(
value
,
dtype
)
def
convert
(
value
):
"""Convert a value to expression."""
if
isinstance
(
value
,
Number
):
return
const
(
value
)
elif
isinstance
(
value
,
list
):
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
else
:
if
not
isinstance
(
value
,
NodeBase
):
raise
ValueError
(
"don't know how to handle type
%
s"
%
type
(
value
))
def
_push_arg
(
arg
):
def
_push_arg
(
arg
):
a
=
ArgVariant
()
a
=
ArgVariant
()
...
@@ -147,9 +168,16 @@ def _make_function(handle, name):
...
@@ -147,9 +168,16 @@ def _make_function(handle, name):
doc_str
=
doc_str
%
(
desc
,
param_str
)
doc_str
=
doc_str
%
(
desc
,
param_str
)
arg_names
=
[
py_str
(
arg_names
[
i
])
for
i
in
range
(
num_args
.
value
)]
arg_names
=
[
py_str
(
arg_names
[
i
])
for
i
in
range
(
num_args
.
value
)]
def
func
(
*
args
,
**
kwargs
):
def
func
(
*
args
):
"""TVM function"""
"""TVM function"""
for
arg
in
args
:
cargs
=
[]
for
x
in
args
:
if
isinstance
(
x
,
list
):
cargs
.
append
(
convert
(
x
))
else
:
cargs
.
append
(
x
)
for
arg
in
cargs
:
_push_arg
(
arg
)
_push_arg
(
arg
)
ret_val
=
ArgVariant
()
ret_val
=
ArgVariant
()
ret_typeid
=
ctypes
.
c_int
()
ret_typeid
=
ctypes
.
c_int
()
...
...
python/tvm/domain.py
0 → 100644
View file @
062bb853
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
_function_internal
@register_node
class
Array
(
NodeBase
):
def
__getitem__
(
self
,
i
):
if
i
>=
len
(
self
):
raise
IndexError
(
"array index out ot range"
)
return
_function_internal
.
_ArrayGetItem
(
self
,
i
)
def
__len__
(
self
):
return
_function_internal
.
_ArraySize
(
self
)
def
__repr__
(
self
):
return
'['
+
(
','
.
join
(
str
(
x
)
for
x
in
self
))
+
']'
python/tvm/expr.py
View file @
062bb853
...
@@ -52,6 +52,10 @@ class CmpExpr(Expr):
...
@@ -52,6 +52,10 @@ class CmpExpr(Expr):
class
LogicalExpr
(
Expr
):
class
LogicalExpr
(
Expr
):
pass
pass
@register_node
(
"Variable"
)
class
Var
(
Expr
):
pass
@register_node
@register_node
class
FloatImm
(
ConstExpr
):
class
FloatImm
(
ConstExpr
):
pass
pass
...
...
python/tvm/function.py
View file @
062bb853
...
@@ -8,6 +8,7 @@ int32 = "int32"
...
@@ -8,6 +8,7 @@ int32 = "int32"
float32
=
"float32"
float32
=
"float32"
def
const
(
value
,
dtype
=
None
):
def
const
(
value
,
dtype
=
None
):
"""construct a constant"""
if
dtype
is
None
:
if
dtype
is
None
:
if
isinstance
(
value
,
_Integral
):
if
isinstance
(
value
,
_Integral
):
dtype
=
'int32'
dtype
=
'int32'
...
@@ -16,12 +17,26 @@ def const(value, dtype=None):
...
@@ -16,12 +17,26 @@ def const(value, dtype=None):
return
_function_internal
.
_const
(
value
,
dtype
)
return
_function_internal
.
_const
(
value
,
dtype
)
def
_symbol
(
value
):
def
Var
(
name
=
"tindex"
,
dtype
=
int32
):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return
_function_internal
.
_Var
(
name
,
dtype
)
def
convert
(
value
):
"""Convert a value to expression."""
"""Convert a value to expression."""
if
isinstance
(
value
,
_Number
):
if
isinstance
(
value
,
_Number
):
return
const
(
value
)
return
const
(
value
)
elif
isinstance
(
value
,
list
):
elif
isinstance
(
value
,
list
):
value
=
[
_symbol
(
x
)
for
x
in
value
]
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
return
_function_internal
.
_Array
(
*
value
)
else
:
else
:
return
value
return
value
...
...
python/tvm/stmt.py
View file @
062bb853
...
@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt):
...
@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt):
@register_node
@register_node
class
For
(
Stmt
):
class
For
(
Stmt
):
Serial
=
0
Parallel
=
1
Vectorized
=
2
Unrolled
=
3
pass
pass
@register_node
@register_node
...
...
src/c_api/c_api_function.cc
View file @
062bb853
...
@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str)
...
@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str)
os
<<
args
.
at
(
0
).
operator
Expr
();
os
<<
args
.
at
(
0
).
operator
Expr
();
}
else
if
(
dynamic_cast
<
const
BaseStmtNode
*>
(
sptr
.
get
()))
{
}
else
if
(
dynamic_cast
<
const
BaseStmtNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
Stmt
();
os
<<
args
.
at
(
0
).
operator
Stmt
();
}
else
{
LOG
(
FATAL
)
<<
"don't know how to print input NodeBaseType"
;
}
}
*
ret
=
os
.
str
();
*
ret
=
os
.
str
();
})
})
.
add_argument
(
"expr"
,
"Node"
,
"expression to be printed"
);
.
add_argument
(
"expr"
,
"Node"
,
"expression to be printed"
);
TVM_REGISTER_API
(
_Array
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
data
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
CHECK
(
args
.
at
(
i
).
type_id
==
kNodeHandle
);
data
.
push_back
(
args
.
at
(
i
).
sptr
);
}
auto
node
=
std
::
make_shared
<
ArrayNode
>
();
node
->
data
=
std
::
move
(
data
);
ret
->
type_id
=
kNodeHandle
;
ret
->
sptr
=
node
;
});
TVM_REGISTER_API
(
_ArrayGetItem
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
int64_t
i
=
args
.
at
(
1
);
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
auto
*
n
=
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
());
CHECK_LT
(
static_cast
<
size_t
>
(
i
),
n
->
data
.
size
())
<<
"out of bound of array"
;
ret
->
sptr
=
n
->
data
[
i
];
ret
->
type_id
=
kNodeHandle
;
});
TVM_REGISTER_API
(
_ArraySize
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
*
ret
=
static_cast
<
int64_t
>
(
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
())
->
data
.
size
());
});
}
// namespace tvm
}
// namespace tvm
src/c_api/c_api_ir.cc
View file @
062bb853
...
@@ -14,6 +14,30 @@ using namespace Halide::Internal;
...
@@ -14,6 +14,30 @@ using namespace Halide::Internal;
using
ArgStack
=
const
std
::
vector
<
APIVariantValue
>
;
using
ArgStack
=
const
std
::
vector
<
APIVariantValue
>
;
using
RetValue
=
APIVariantValue
;
using
RetValue
=
APIVariantValue
;
TVM_REGISTER_API
(
_Var
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Variable
::
make
(
args
.
at
(
1
),
args
.
at
(
0
));
});
TVM_REGISTER_API
(
_make_For
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
For
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
),
static_cast
<
ForType
>
(
args
.
at
(
3
).
operator
int
()),
static_cast
<
Halide
::
DeviceAPI
>
(
args
.
at
(
4
).
operator
int
()),
args
.
at
(
5
));
});
TVM_REGISTER_API
(
_make_Allocate
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Allocate
::
make
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
),
args
.
at
(
3
),
args
.
at
(
4
));
});
// make from two arguments
// make from two arguments
#define REGISTER_MAKE1(Node) \
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
TVM_REGISTER_API(_make_## Node) \
...
@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select);
...
@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
Let
);
// TODO(tqchen) Call;
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
ProducerConsumer
);
// TODO(tqchen) For;
REGISTER_MAKE3
(
Store
);
REGISTER_MAKE3
(
Store
);
// TODO(tqchen) Provide;
REGISTER_MAKE3
(
Provide
);
// TODO(tqchen) Allocate;
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE1
(
Free
);
// TODO(tqchen) Realize;
// TODO(tqchen) Realize;
REGISTER_MAKE2
(
Block
);
REGISTER_MAKE2
(
Block
);
...
...
src/c_api/c_api_registry.h
View file @
062bb853
...
@@ -96,8 +96,10 @@ struct APIVariantValue {
...
@@ -96,8 +96,10 @@ struct APIVariantValue {
}
}
inline
operator
Expr
()
const
{
inline
operator
Expr
()
const
{
if
(
type_id
==
kNull
)
return
Expr
();
if
(
type_id
==
kNull
)
return
Expr
();
if
(
type_id
==
kLong
)
return
Expr
(
operator
int64_t
());
if
(
type_id
==
kLong
)
return
Expr
(
operator
int
());
if
(
type_id
==
kDouble
)
return
Expr
(
operator
double
());
if
(
type_id
==
kDouble
)
{
return
Expr
(
static_cast
<
float
>
(
operator
double
()));
}
CHECK_EQ
(
type_id
,
kNodeHandle
);
CHECK_EQ
(
type_id
,
kNodeHandle
);
return
Expr
(
sptr
);
return
Expr
(
sptr
);
}
}
...
...
tests/python/test_basic.py
View file @
062bb853
...
@@ -19,7 +19,26 @@ def test_ir():
...
@@ -19,7 +19,26 @@ def test_ir():
assert
isinstance
(
stmt
,
tvm
.
stmt
.
Evaluate
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
Evaluate
)
print
tvm
.
format_str
(
stmt
)
print
tvm
.
format_str
(
stmt
)
def
test_basic
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
c
=
a
+
b
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
def
test_array
():
a
=
tvm
.
convert
([
1
,
2
,
3
])
def
test_stmt
():
print
tvm
.
make
.
Provide
(
'a'
,
[
1
,
2
,
3
],
[
1
,
2
,
3
])
print
tvm
.
make
.
For
(
'a'
,
0
,
1
,
tvm
.
stmt
.
For
.
Serial
,
0
,
tvm
.
make
.
Evaluate
(
0
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_const
()
test_const
()
test_make
()
test_make
()
test_ir
()
test_ir
()
test_basic
()
test_stmt
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment