Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
622cee7a
Commit
622cee7a
authored
Oct 26, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add most of IR constructors
parent
a41d644a
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
325 additions
and
31 deletions
+325
-31
HalideIR
+1
-1
include/tvm/base.h
+1
-1
include/tvm/expr.h
+2
-2
python/tvm/__init__.py
+2
-0
python/tvm/_ctypes/_api.py
+22
-8
python/tvm/expr.py
+125
-13
python/tvm/function.py
+1
-0
python/tvm/make.py
+1
-0
python/tvm/stmt.py
+56
-0
src/c_api/c_api_common.h
+4
-3
src/c_api/c_api_ir.cc
+83
-0
src/c_api/c_api_registry.h
+10
-2
tests/python/test_basic.py
+17
-1
No files found.
HalideIR
@
87209936
Subproject commit
2a1001108b9112c4e594c456ffd364b57db10b6b
Subproject commit
872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
include/tvm/base.h
View file @
622cee7a
...
@@ -8,12 +8,12 @@
...
@@ -8,12 +8,12 @@
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <dmlc/registry.h>
#include <tvm/node.h>
#include <string>
#include <string>
#include <memory>
#include <memory>
#include <functional>
#include <functional>
#include <typeinfo>
#include <typeinfo>
#include <type_traits>
#include <type_traits>
#include <tvm/node.h>
namespace
tvm
{
namespace
tvm
{
...
...
include/tvm/expr.h
View file @
622cee7a
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
#ifndef TVM_EXPR_H_
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#define TVM_EXPR_H_
#include <type_traits>
#include <ir/Expr.h>
#include <ir/Expr.h>
#include <type_traits>
#include "./base.h"
#include "./base.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -15,5 +15,5 @@ namespace tvm {
...
@@ -15,5 +15,5 @@ namespace tvm {
using
Halide
::
Type
;
using
Halide
::
Type
;
using
Halide
::
Expr
;
using
Halide
::
Expr
;
}
// namespace
std
}
// namespace
tvm
#endif // TVM_EXPR_H_
#endif // TVM_EXPR_H_
python/tvm/__init__.py
View file @
622cee7a
...
@@ -4,3 +4,5 @@ from __future__ import absolute_import as _abs
...
@@ -4,3 +4,5 @@ from __future__ import absolute_import as _abs
from
.function
import
*
from
.function
import
*
from
._ctypes._api
import
register_node
from
._ctypes._api
import
register_node
from
.
import
expr
from
.
import
expr
from
.
import
stmt
from
.
import
make
python/tvm/_ctypes/_api.py
View file @
622cee7a
...
@@ -162,19 +162,23 @@ def _make_function(handle, name):
...
@@ -162,19 +162,23 @@ def _make_function(handle, name):
return
func
return
func
def
register_node
(
type_key
):
def
register_node
(
type_key
=
None
):
"""register node type
"""register node type
Parameters
Parameters
----------
----------
type_key : str
type_key : str
or cls
The type key of the node
The type key of the node
"""
"""
def
register
(
cls
):
if
isinstance
(
type_key
,
str
):
NODE_TYPE
[
type_key
]
=
cls
def
register
(
cls
):
NODE_TYPE
[
type_key
]
=
cls
return
cls
return
register
else
:
cls
=
type_key
NODE_TYPE
[
cls
.
__name__
]
=
cls
return
cls
return
cls
return
register
def
_init_function_module
(
root_namespace
):
def
_init_function_module
(
root_namespace
):
"""List and add all the functions to current module."""
"""List and add all the functions to current module."""
...
@@ -189,11 +193,21 @@ def _init_function_module(root_namespace):
...
@@ -189,11 +193,21 @@ def _init_function_module(root_namespace):
module_obj
=
sys
.
modules
[
"
%
s.function"
%
root_namespace
]
module_obj
=
sys
.
modules
[
"
%
s.function"
%
root_namespace
]
module_internal
=
sys
.
modules
[
"
%
s._function_internal"
%
root_namespace
]
module_internal
=
sys
.
modules
[
"
%
s._function_internal"
%
root_namespace
]
module_make
=
sys
.
modules
[
"
%
s.make"
%
root_namespace
]
for
name
in
op_names
:
for
name
in
op_names
:
hdl
=
FunctionHandle
()
hdl
=
FunctionHandle
()
check_call
(
_LIB
.
TVMGetFunctionHandle
(
c_str
(
name
),
ctypes
.
byref
(
hdl
)))
check_call
(
_LIB
.
TVMGetFunctionHandle
(
c_str
(
name
),
ctypes
.
byref
(
hdl
)))
function
=
_make_function
(
hdl
,
name
)
if
name
.
startswith
(
"_make_"
):
if
function
.
__name__
.
startswith
(
'_'
):
fname
=
name
[
6
:]
else
:
fname
=
name
function
=
_make_function
(
hdl
,
fname
)
if
name
.
startswith
(
"_make_"
):
setattr
(
module_make
,
function
.
__name__
,
function
)
elif
function
.
__name__
.
startswith
(
'_'
):
setattr
(
module_internal
,
function
.
__name__
,
function
)
setattr
(
module_internal
,
function
.
__name__
,
function
)
else
:
else
:
setattr
(
module_obj
,
function
.
__name__
,
function
)
setattr
(
module_obj
,
function
.
__name__
,
function
)
python/tvm/expr.py
View file @
622cee7a
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
function
as
_func
from
.
import
function
as
_func
from
.
import
make
as
_make
class
Expr
(
NodeBase
):
class
Expr
(
NodeBase
):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
_func
.
format_str
(
self
)
return
_func
.
format_str
(
self
)
def
__add__
(
self
,
other
):
def
__add__
(
self
,
other
):
return
binary_op
(
'+'
,
self
,
other
)
return
_make
.
Add
(
self
,
other
)
def
__radd__
(
self
,
other
):
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
def
__sub__
(
self
,
other
):
return
binary_op
(
'-'
,
self
,
other
)
return
_make
.
Sub
(
self
,
other
)
def
__rsub__
(
self
,
other
):
def
__rsub__
(
self
,
other
):
return
binary_op
(
'-'
,
other
,
self
)
return
_make
.
Sub
(
other
,
self
)
def
__mul__
(
self
,
other
):
def
__mul__
(
self
,
other
):
return
binary_op
(
'*'
,
self
,
other
)
return
_make
.
Mul
(
self
,
other
)
def
__rmul__
(
self
,
other
):
def
__rmul__
(
self
,
other
):
return
binary_op
(
'*'
,
other
,
self
)
return
_make
.
Mul
(
other
,
self
)
def
__div__
(
self
,
other
):
def
__div__
(
self
,
other
):
return
binary_op
(
'/'
,
self
,
other
)
return
_make
.
Div
(
self
,
other
)
def
__rdiv__
(
self
,
other
):
def
__rdiv__
(
self
,
other
):
return
binary_op
(
'/'
,
other
,
self
)
return
_make
.
Div
(
other
,
self
)
def
__truediv__
(
self
,
other
):
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
return
self
.
__div__
(
other
)
...
@@ -39,15 +40,126 @@ class Expr(NodeBase):
...
@@ -39,15 +40,126 @@ class Expr(NodeBase):
def
__neg__
(
self
):
def
__neg__
(
self
):
return
self
.
__mul__
(
-
1
)
return
self
.
__mul__
(
-
1
)
class
ConstExpr
(
Expr
):
pass
class
BinaryOpExpr
(
Expr
):
pass
class
CmpExpr
(
Expr
):
pass
class
LogicalExpr
(
Expr
):
pass
@register_node
class
FloatImm
(
ConstExpr
):
pass
@register_node
class
IntImm
(
ConstExpr
):
pass
@register_node
class
UIntImm
(
ConstExpr
):
pass
@register_node
class
StringImm
(
ConstExpr
):
pass
@register_node
class
Cast
(
Expr
):
pass
@register_node
class
Variable
(
Expr
):
pass
@register_node
class
Add
(
BinaryOpExpr
):
pass
@register_node
class
Sub
(
BinaryOpExpr
):
pass
@register_node
class
Mul
(
BinaryOpExpr
):
pass
@register_node
class
Div
(
BinaryOpExpr
):
pass
@register_node
class
Mod
(
BinaryOpExpr
):
pass
@register_node
class
Min
(
BinaryOpExpr
):
pass
@register_node
class
Max
(
BinaryOpExpr
):
pass
@register_node
class
EQ
(
CmpExpr
):
pass
@register_node
class
NE
(
CmpExpr
):
pass
@register_node
class
LT
(
CmpExpr
):
pass
@register_node
class
LE
(
CmpExpr
):
pass
@register_node
class
GT
(
CmpExpr
):
pass
@register_node
class
GE
(
CmpExpr
):
pass
@register_node
class
And
(
LogicalExpr
):
pass
@register_node
class
Or
(
LogicalExpr
):
pass
@register_node
class
Not
(
LogicalExpr
):
pass
@register_node
class
Select
(
Expr
):
pass
@register_node
class
Load
(
Expr
):
pass
@register_node
class
Ramp
(
Expr
):
pass
@register_node
(
"IntImm"
)
@register_node
class
IntImm
(
Expr
):
class
Broadcast
(
Expr
):
pass
pass
@register_node
(
"UIntImm"
)
@register_node
class
UIntImm
(
Expr
):
class
Call
(
Expr
):
pass
pass
@register_node
(
"FloatImm"
)
@register_node
class
FloatImm
(
Expr
):
class
Let
(
Expr
):
pass
pass
python/tvm/function.py
View file @
622cee7a
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import as _abs
from
numbers
import
Number
as
_Number
,
Integral
as
_Integral
from
numbers
import
Number
as
_Number
,
Integral
as
_Integral
from
._ctypes._api
import
_init_function_module
from
._ctypes._api
import
_init_function_module
from
.
import
_function_internal
from
.
import
_function_internal
from
.
import
make
as
_make
int32
=
"int32"
int32
=
"int32"
float32
=
"float32"
float32
=
"float32"
...
...
python/tvm/make.py
0 → 100644
View file @
622cee7a
"""namespace of IR node builder make function"""
python/tvm/stmt.py
0 → 100644
View file @
622cee7a
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
function
as
_func
from
.
import
make
as
_make
class
Stmt
(
NodeBase
):
def
__repr__
(
self
):
return
_func
.
format_str
(
self
)
@register_node
class
LetStmt
(
Stmt
):
pass
@register_node
class
AssertStmt
(
Stmt
):
pass
@register_node
class
ProducerConsumer
(
Stmt
):
pass
@register_node
class
For
(
Stmt
):
pass
@register_node
class
Store
(
Stmt
):
pass
@register_node
class
Provide
(
Stmt
):
pass
@register_node
class
Allocate
(
Stmt
):
pass
@register_node
class
Free
(
Stmt
):
pass
@register_node
class
Realize
(
Stmt
):
pass
@register_node
class
Block
(
Stmt
):
pass
@register_node
class
IfThenElse
(
Stmt
):
pass
@register_node
class
Evaluate
(
Stmt
):
pass
src/c_api/c_api_common.h
View file @
622cee7a
...
@@ -12,19 +12,20 @@
...
@@ -12,19 +12,20 @@
#include <tvm/c_api.h>
#include <tvm/c_api.h>
#include <vector>
#include <vector>
#include <string>
#include <string>
#include <exception>
#include "./c_api_registry.h"
#include "./c_api_registry.h"
/*! \brief macro to guard beginning and end section of all functions */
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(
dmlc::E
rror &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END() } catch(
std::runtime_e
rror &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
/*!
* \brief every function starts with API_BEGIN();
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
* The finally clause contains procedure to cleanup states when an error happens.
*/
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(
dmlc::E
rror &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
#define API_END_HANDLE_ERROR(Finalize) } catch(
std::runtime_e
rror &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void
TVMAPISetLastError
(
const
char
*
msg
);
void
TVMAPISetLastError
(
const
char
*
msg
);
...
@@ -33,7 +34,7 @@ void TVMAPISetLastError(const char* msg);
...
@@ -33,7 +34,7 @@ void TVMAPISetLastError(const char* msg);
* \param e the exception
* \param e the exception
* \return the return value of API after exception is handled
* \return the return value of API after exception is handled
*/
*/
inline
int
TVMAPIHandleException
(
const
dmlc
::
E
rror
&
e
)
{
inline
int
TVMAPIHandleException
(
const
std
::
runtime_e
rror
&
e
)
{
TVMAPISetLastError
(
e
.
what
());
TVMAPISetLastError
(
e
.
what
());
return
-
1
;
return
-
1
;
}
}
...
...
src/c_api/c_api_ir.cc
0 → 100644
View file @
622cee7a
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to IR build
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace
tvm
{
using
namespace
Halide
::
Internal
;
using
ArgStack
=
const
std
::
vector
<
APIVariantValue
>
;
using
RetValue
=
APIVariantValue
;
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0)); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1)); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
Expr a = args.at(0), b = args.at(1); \
match_types(a, b); \
*ret = Node::make(a, b); \
}) \
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE2
(
IntImm
);
REGISTER_MAKE2
(
UIntImm
);
REGISTER_MAKE2
(
FloatImm
);
REGISTER_MAKE1
(
StringImm
);
REGISTER_MAKE_BINARY_OP
(
Add
);
REGISTER_MAKE_BINARY_OP
(
Sub
);
REGISTER_MAKE_BINARY_OP
(
Mul
);
REGISTER_MAKE_BINARY_OP
(
Div
);
REGISTER_MAKE_BINARY_OP
(
Mod
);
REGISTER_MAKE_BINARY_OP
(
Min
);
REGISTER_MAKE_BINARY_OP
(
Max
);
REGISTER_MAKE_BINARY_OP
(
EQ
);
REGISTER_MAKE_BINARY_OP
(
NE
);
REGISTER_MAKE_BINARY_OP
(
LT
);
REGISTER_MAKE_BINARY_OP
(
LE
);
REGISTER_MAKE_BINARY_OP
(
GT
);
REGISTER_MAKE_BINARY_OP
(
GE
);
REGISTER_MAKE_BINARY_OP
(
And
);
REGISTER_MAKE_BINARY_OP
(
Or
);
REGISTER_MAKE1
(
Not
);
REGISTER_MAKE3
(
Select
);
REGISTER_MAKE3
(
Ramp
);
REGISTER_MAKE2
(
Broadcast
);
REGISTER_MAKE3
(
Let
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
// TODO(tqchen) For;
REGISTER_MAKE3
(
Store
);
// TODO(tqchen) Provide;
// TODO(tqchen) Allocate;
REGISTER_MAKE1
(
Free
);
// TODO(tqchen) Realize;
REGISTER_MAKE2
(
Block
);
REGISTER_MAKE3
(
IfThenElse
);
REGISTER_MAKE1
(
Evaluate
);
}
// namespace tvm
src/c_api/c_api_registry.h
View file @
622cee7a
...
@@ -24,7 +24,7 @@ inline std::string Type2String(const Type& t) {
...
@@ -24,7 +24,7 @@ inline std::string Type2String(const Type& t) {
inline
Type
String2Type
(
std
::
string
s
)
{
inline
Type
String2Type
(
std
::
string
s
)
{
std
::
istringstream
is
(
s
);
std
::
istringstream
is
(
s
);
halide_type_code_t
code
;
halide_type_code_t
code
=
Type
::
Int
;
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
code
=
Type
::
Int
;
s
=
s
.
substr
(
3
);
code
=
Type
::
Int
;
s
=
s
.
substr
(
3
);
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
...
@@ -36,7 +36,7 @@ inline Type String2Type(std::string s) {
...
@@ -36,7 +36,7 @@ inline Type String2Type(std::string s) {
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
}
int
bits
,
lanes
=
1
;
int
bits
=
32
,
lanes
=
1
;
if
(
sscanf
(
s
.
c_str
(),
"%dx%d"
,
&
bits
,
&
lanes
)
==
0
)
{
if
(
sscanf
(
s
.
c_str
(),
"%dx%d"
,
&
bits
,
&
lanes
)
==
0
)
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
}
...
@@ -109,12 +109,20 @@ struct APIVariantValue {
...
@@ -109,12 +109,20 @@ struct APIVariantValue {
CHECK_EQ
(
type_id
,
kLong
);
CHECK_EQ
(
type_id
,
kLong
);
return
v_union
.
v_long
;
return
v_union
.
v_long
;
}
}
inline
operator
uint64_t
()
const
{
CHECK_EQ
(
type_id
,
kLong
);
return
v_union
.
v_long
;
}
inline
operator
int
()
const
{
inline
operator
int
()
const
{
CHECK_EQ
(
type_id
,
kLong
);
CHECK_EQ
(
type_id
,
kLong
);
CHECK_LE
(
v_union
.
v_long
,
CHECK_LE
(
v_union
.
v_long
,
std
::
numeric_limits
<
int
>::
max
());
std
::
numeric_limits
<
int
>::
max
());
return
v_union
.
v_long
;
return
v_union
.
v_long
;
}
}
inline
operator
bool
()
const
{
CHECK_EQ
(
type_id
,
kLong
);
return
v_union
.
v_long
!=
0
;
}
inline
operator
std
::
string
()
const
{
inline
operator
std
::
string
()
const
{
CHECK_EQ
(
type_id
,
kStr
);
CHECK_EQ
(
type_id
,
kStr
);
return
str
;
return
str
;
...
...
tests/python/test_basic.py
View file @
622cee7a
...
@@ -2,8 +2,24 @@ import tvm
...
@@ -2,8 +2,24 @@ import tvm
def
test_const
():
def
test_const
():
x
=
tvm
.
const
(
1
)
x
=
tvm
.
const
(
1
)
assert
x
.
type
==
'int32'
assert
x
.
d
type
==
'int32'
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
def
test_make
():
x
=
tvm
.
const
(
1
)
y
=
tvm
.
make
.
IntImm
(
'int32'
,
1
)
z
=
x
+
y
print
tvm
.
format_str
(
z
)
def
test_ir
():
x
=
tvm
.
const
(
1
)
y
=
tvm
.
make
.
IntImm
(
'int32'
,
1
)
z
=
x
+
y
stmt
=
tvm
.
make
.
Evaluate
(
z
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
Evaluate
)
print
tvm
.
format_str
(
stmt
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_const
()
test_const
()
test_make
()
test_ir
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment