Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5324b211
Commit
5324b211
authored
Oct 18, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[API] expose dir
parent
8de0a083
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
237 additions
and
2 deletions
+237
-2
include/tvm/base.h
+35
-0
include/tvm/c_api.h
+10
-0
include/tvm/domain.h
+5
-1
include/tvm/expr.h
+9
-0
include/tvm/expr_node.h
+42
-0
include/tvm/expr_util.h
+32
-0
include/tvm/op.h
+18
-0
python/tvm/cpp/_ctypes/_api.py
+22
-0
python/tvm/cpp/expr.py
+1
-1
src/c_api/c_api.cc
+46
-0
src/c_api/c_api_function.cc
+7
-0
src/expr/op.cc
+7
-0
tests/python/test_cpp.py
+3
-0
No files found.
include/tvm/base.h
View file @
5324b211
...
...
@@ -131,6 +131,20 @@ class NodeRef {
inline
NodeType
node_type
()
const
;
/*! \return wheyjer the expression is null */
inline
bool
is_null
()
const
;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline
bool
operator
==
(
const
NodeRef
&
other
)
const
;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline
bool
operator
!=
(
const
NodeRef
&
other
)
const
;
/*! \return the hash function for NodeRef */
inline
size_t
hash
()
const
;
protected
:
template
<
typename
T
,
typename
>
...
...
@@ -182,5 +196,26 @@ inline bool NodeRef::is_null() const {
return
node_
.
get
()
==
nullptr
;
}
inline
bool
NodeRef
::
operator
==
(
const
NodeRef
&
other
)
const
{
return
node_
.
get
()
==
other
.
node_
.
get
();
}
inline
bool
NodeRef
::
operator
!=
(
const
NodeRef
&
other
)
const
{
return
node_
.
get
()
!=
other
.
node_
.
get
();
}
inline
size_t
NodeRef
::
hash
()
const
{
return
std
::
hash
<
Node
*>
()(
node_
.
get
());
}
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
NodeRef
>
{
std
::
size_t
operator
()(
const
::
tvm
::
NodeRef
&
k
)
const
{
return
k
.
hash
();
}
};
}
// namespace std
#endif // TVM_BASE_H_
include/tvm/c_api.h
View file @
5324b211
...
...
@@ -136,4 +136,14 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
ArgVariant
*
out_value
,
int
*
out_typeid
);
/*!
* \brief get attributes names in the node.
* \param handle The node handle
* \param out_size The number of functions
* \param out_array The array of function names.
*/
TVM_DLL
int
TVMNodeListAttrNames
(
NodeHandle
handle
,
int
*
out_size
,
const
char
***
out_array
);
#endif // TVM_C_API_H_
include/tvm/domain.h
View file @
5324b211
...
...
@@ -14,9 +14,13 @@
namespace
tvm
{
//
using Domain = Array<Range>;
//using Domain = Array<Range>;
class
RDomain
:
public
NodeRef
{
};
}
// namespace tvm
...
...
include/tvm/expr.h
View file @
5324b211
...
...
@@ -113,4 +113,13 @@ inline Expr constant(T value) {
}
}
// namespace tvm
namespace
std
{
template
<>
struct
hash
<::
tvm
::
Expr
>
{
std
::
size_t
operator
()(
const
::
tvm
::
NodeRef
&
k
)
const
{
return
k
.
hash
();
}
};
}
// namespace std
#endif // TVM_EXPR_H_
include/tvm/expr_node.h
View file @
5324b211
...
...
@@ -46,6 +46,7 @@ class IntNode : public ExprNode {
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"value"
,
&
value
);
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
};
...
...
@@ -64,6 +65,7 @@ class FloatNode : public ExprNode {
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"value"
,
&
value
);
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
};
...
...
@@ -94,6 +96,7 @@ class UnaryOpNode : public ExprNode {
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"op"
,
&
op
);
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"src"
,
&
src
);
...
...
@@ -130,12 +133,51 @@ struct BinaryOpNode : public ExprNode {
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"op"
,
&
op
);
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"lhs"
,
&
lhs
);
fvisit
(
"rhs"
,
&
rhs
);
}
};
/*! \brief Binary mapping operator */
struct
ReduceNode
:
public
ExprNode
{
public
:
/*! \brief The operator */
const
BinaryOp
*
op
;
/*! \brief The source operand */
Expr
src
;
/*! \brief The reduction domain */
RDomain
rdom
;
/*! \brief constructor, do not use constructor */
ReduceNode
()
{
node_type_
=
kReduceNode
;
}
ReduceNode
(
const
BinaryOp
*
op
,
Expr
&&
src
,
RDomain
&&
rdom
)
:
op
(
op
),
src
(
std
::
move
(
src
)),
rdom
(
std
::
move
(
rdom
))
{
node_type_
=
kReduceNode
;
dtype_
=
this
->
src
.
dtype
();
}
~
ReduceNode
()
{
this
->
Destroy
();
}
const
char
*
type_key
()
const
override
{
return
"ReduceNode"
;
}
void
Verify
()
const
override
{
CHECK_EQ
(
dtype_
,
src
.
dtype
());
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"op"
,
&
op
);
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"src"
,
&
src
);
fvisit
(
"rdom"
,
&
rdom
);
}
};
}
// namespace tvm
#endif // TVM_EXPR_NODE_H_
include/tvm/expr_util.h
View file @
5324b211
...
...
@@ -7,9 +7,41 @@
#define TVM_EXPR_UTIL_H_
#include "./expr.h"
#include "./expr_node.h"
namespace
tvm
{
/*!
* \brief simplify the expression src
* \param src The source expression
* \return the simplified expression.
*/
Expr
Simplify
(
const
Expr
&
src
);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
*/
template
<
typename
FVisit
>
inline
void
Visit
(
const
Expr
&
expr
,
FVisit
fvisit
)
{
// TODO(tqchen) change to stack based impl.
switch
(
expr
.
node_type
())
{
case
kBinaryOpNode
:
{
const
auto
*
n
=
expr
.
Get
<
BinaryOpNode
>
();
Visit
(
n
->
lhs
,
fvisit
);
Visit
(
n
->
rhs
,
fvisit
);
break
;
}
case
kUnaryOpNode
:
{
const
auto
*
n
=
expr
.
Get
<
UnaryOpNode
>
();
Visit
(
n
->
src
,
fvisit
);
break
;
}
default
:
break
;
}
fvisit
(
expr
);
}
}
// namespace tvm
...
...
include/tvm/op.h
View file @
5324b211
...
...
@@ -9,6 +9,7 @@
#include <dmlc/registry.h>
#include <string>
#include "./expr.h"
#include "./domain.h"
namespace
tvm
{
...
...
@@ -27,6 +28,13 @@ class BinaryOp {
*/
Expr
operator
()(
Expr
lhs
,
Expr
rhs
)
const
;
/*!
* \brief make a reduction of src over rdom,
* \param src Source expression.
* \param rdom reduction domain.
* \return the result expr
*/
Expr
Reduce
(
Expr
src
,
RDomain
rdom
)
const
;
/*!
* \brief get binary op by name
* \param name name of operator
*/
...
...
@@ -112,6 +120,12 @@ class MinOp : public BinaryOp {
return (*op)(lhs, rhs); \
}
#define DEFINE_REDUCE_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr src, RDomain rdom) { \
static const BinaryOp* op = BinaryOp::Get(#OpName); \
return op->Reduce(src, rdom); \
}
DEFINE_BINARY_OP_OVERLOAD
(
+
);
DEFINE_BINARY_OP_OVERLOAD
(
-
);
DEFINE_BINARY_OP_OVERLOAD
(
*
);
...
...
@@ -120,6 +134,10 @@ DEFINE_BINARY_OP_OVERLOAD(/);
DEFINE_BINARY_OP_FUNCTION
(
max
);
DEFINE_BINARY_OP_FUNCTION
(
min
);
DEFINE_REDUCE_FUNCTION
(
max
,
max
);
DEFINE_REDUCE_FUNCTION
(
min
,
min
);
DEFINE_REDUCE_FUNCTION
(
sum
,
+
);
// overload negation
inline
Expr
operator
-
(
Expr
src
)
{
return
src
*
(
-
1
);
...
...
python/tvm/cpp/_ctypes/_api.py
View file @
5324b211
...
...
@@ -11,6 +11,7 @@ from .._base import _LIB
from
.._base
import
c_str
,
py_str
,
string_types
from
.._base
import
FunctionHandle
,
NodeHandle
from
.._base
import
check_call
,
ctypes2docstring
from
..
import
_function_internal
class
ArgVariant
(
ctypes
.
Union
):
...
...
@@ -71,6 +72,27 @@ class NodeBase(object):
ctypes
.
byref
(
ret_val
),
ctypes
.
byref
(
ret_typeid
)))
return
RET_SWITCH
[
ret_typeid
.
value
](
ret_val
)
def
__hash__
(
self
):
return
_function_internal
.
_raw_ptr
(
self
)
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
NodeBase
):
return
False
return
self
.
__hash__
()
==
other
.
__hash__
()
def
__ne__
(
self
,
other
):
return
not
self
.
__eq__
(
other
)
def
__dir__
(
self
):
plist
=
ctypes
.
POINTER
(
ctypes
.
c_char_p
)()
size
=
ctypes
.
c_uint
()
check_call
(
_LIB
.
TVMNodeListAttrNames
(
self
.
handle
,
ctypes
.
byref
(
size
),
ctypes
.
byref
(
plist
)))
names
=
[]
for
i
in
range
(
size
.
value
):
names
.
append
(
py_str
(
plist
[
i
]))
return
names
def
_push_arg
(
arg
):
a
=
ArgVariant
()
...
...
python/tvm/cpp/expr.py
View file @
5324b211
...
...
@@ -42,5 +42,5 @@ class Var(Expr):
pass
@register_node
(
"BinaryOpNode"
)
class
BinaryOp
Node
(
Expr
):
class
BinaryOp
Expr
(
Expr
):
pass
src/c_api/c_api.cc
View file @
5324b211
...
...
@@ -59,6 +59,29 @@ struct APIAttrGetter : public AttrVisitor {
}
};
struct
APIAttrDir
:
public
AttrVisitor
{
std
::
vector
<
std
::
string
>*
names
;
void
Visit
(
const
char
*
key
,
double
*
value
)
override
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
override
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
DataType
*
value
)
override
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
std
::
string
*
value
)
override
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
const
UnaryOp
**
value
)
override
{
names
->
push_back
(
key
);
}
void
Visit
(
const
char
*
key
,
const
BinaryOp
**
value
)
override
{
names
->
push_back
(
key
);
}
};
const
char
*
TVMGetLastError
()
{
return
TVMAPIThreadLocalStore
::
Get
()
->
last_error
.
c_str
();
}
...
...
@@ -190,6 +213,29 @@ int TVMNodeGetAttr(NodeHandle handle,
API_END_HANDLE_ERROR
(
ret
->
Clear
());
}
int
TVMNodeListAttrNames
(
NodeHandle
handle
,
int
*
out_size
,
const
char
***
out_array
)
{
TVMAPIThreadLocalEntry
*
ret
=
TVMAPIThreadLocalStore
::
Get
();
API_BEGIN
();
ret
->
ret_vec_str
.
clear
();
TVMAPINode
*
tnode
=
static_cast
<
TVMAPINode
*>
(
handle
);
APIAttrDir
dir
;
dir
.
names
=
&
(
ret
->
ret_vec_str
);
(
*
tnode
)
->
VisitAttrs
(
&
dir
);
(
*
tnode
)
->
VisitNodeRefFields
([
ret
](
const
char
*
key
,
NodeRef
*
ref
)
{
ret
->
ret_vec_str
.
push_back
(
key
);
});
ret
->
ret_vec_charp
.
clear
();
for
(
size_t
i
=
0
;
i
<
ret
->
ret_vec_str
.
size
();
++
i
)
{
ret
->
ret_vec_charp
.
push_back
(
ret
->
ret_vec_str
[
i
].
c_str
());
}
*
out_array
=
dmlc
::
BeginPtr
(
ret
->
ret_vec_charp
);
*
out_size
=
static_cast
<
int
>
(
ret
->
ret_vec_str
.
size
());
API_END
();
}
inline
void
TVMAPIThreadLocalEntry
::
SetReturn
(
ArgVariant
*
ret_val
,
int
*
ret_typeid
)
{
APIVariantValue
&
rv
=
ret_value
;
...
...
src/c_api/c_api_function.cc
View file @
5324b211
...
...
@@ -46,6 +46,13 @@ TVM_REGISTER_API(_binary_op)
.
add_argument
(
"lhs"
,
"Expr"
,
"left operand"
)
.
add_argument
(
"rhs"
,
"Expr"
,
"right operand"
);
TVM_REGISTER_API
(
_raw_ptr
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
*
ret
=
reinterpret_cast
<
int64_t
>
(
args
.
at
(
0
).
sptr
.
get
());
})
.
add_argument
(
"src"
,
"NodeBase"
,
"the node base"
);
// transformations
TVM_REGISTER_API
(
format_str
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
...
...
src/expr/op.cc
View file @
5324b211
...
...
@@ -20,6 +20,13 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
return
Expr
(
std
::
move
(
nptr
));
}
Expr
BinaryOp
::
Reduce
(
Expr
src
,
RDomain
rdom
)
const
{
auto
nptr
=
std
::
make_shared
<
ReduceNode
>
(
this
,
std
::
move
(
src
),
std
::
move
(
rdom
));
nptr
->
Verify
();
return
Expr
(
std
::
move
(
nptr
));
}
const
BinaryOp
*
BinaryOp
::
Get
(
const
char
*
name
)
{
const
auto
*
op
=
dmlc
::
Registry
<
BinaryOpReg
>::
Find
(
name
);
CHECK
(
op
!=
nullptr
)
<<
"cannot find "
<<
name
;
...
...
tests/python/test_cpp.py
View file @
5324b211
...
...
@@ -4,6 +4,9 @@ def test_basic():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
c
=
a
+
b
assert
a
==
c
.
lhs
assert
c
.
dtype
==
tvm
.
int32
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment