Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8de0a083
Commit
8de0a083
authored
Oct 18, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[OP] enable binary op
parent
1a7fb9f9
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
177 additions
and
40 deletions
+177
-40
include/tvm/op.h
+50
-17
python/tvm/cpp/expr.py
+35
-1
python/tvm/cpp/function.py
+56
-1
src/c_api/c_api_function.cc
+15
-7
src/expr/op.cc
+19
-12
tests/python/test_cpp.py
+2
-2
No files found.
include/tvm/op.h
View file @
8de0a083
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#ifndef TVM_OP_H_
#ifndef TVM_OP_H_
#define TVM_OP_H_
#define TVM_OP_H_
#include <dmlc/registry.h>
#include <string>
#include <string>
#include "./expr.h"
#include "./expr.h"
...
@@ -14,6 +15,8 @@ namespace tvm {
...
@@ -14,6 +15,8 @@ namespace tvm {
/*! \brief binary operator */
/*! \brief binary operator */
class
BinaryOp
{
class
BinaryOp
{
public
:
public
:
// virtual destructor
virtual
~
BinaryOp
()
{}
/*! \return the function name to be called in binary op */
/*! \return the function name to be called in binary op */
virtual
const
char
*
FunctionName
()
const
=
0
;
virtual
const
char
*
FunctionName
()
const
=
0
;
/*!
/*!
...
@@ -23,6 +26,11 @@ class BinaryOp {
...
@@ -23,6 +26,11 @@ class BinaryOp {
* \return the result expr
* \return the result expr
*/
*/
Expr
operator
()(
Expr
lhs
,
Expr
rhs
)
const
;
Expr
operator
()(
Expr
lhs
,
Expr
rhs
)
const
;
/*!
* \brief get binary op by name
* \param name name of operator
*/
static
const
BinaryOp
*
Get
(
const
char
*
name
);
};
};
...
@@ -37,6 +45,11 @@ class UnaryOp {
...
@@ -37,6 +45,11 @@ class UnaryOp {
* \return the result expr
* \return the result expr
*/
*/
Expr
operator
()(
Expr
src
)
const
;
Expr
operator
()(
Expr
src
)
const
;
/*!
* \brief get unary op by name
* \param name name of operator
*/
static
const
UnaryOp
*
Get
(
const
char
*
name
);
};
};
...
@@ -45,7 +58,6 @@ class AddOp : public BinaryOp {
...
@@ -45,7 +58,6 @@ class AddOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"+"
;
return
"+"
;
}
}
static
AddOp
*
Get
();
};
};
...
@@ -54,7 +66,6 @@ class SubOp : public BinaryOp {
...
@@ -54,7 +66,6 @@ class SubOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"-"
;
return
"-"
;
}
}
static
SubOp
*
Get
();
};
};
...
@@ -63,7 +74,6 @@ class MulOp : public BinaryOp {
...
@@ -63,7 +74,6 @@ class MulOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"*"
;
return
"*"
;
}
}
static
MulOp
*
Get
();
};
};
...
@@ -72,7 +82,6 @@ class DivOp : public BinaryOp {
...
@@ -72,7 +82,6 @@ class DivOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"/"
;
return
"/"
;
}
}
static
DivOp
*
Get
();
};
};
...
@@ -81,7 +90,6 @@ class MaxOp : public BinaryOp {
...
@@ -81,7 +90,6 @@ class MaxOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"max"
;
return
"max"
;
}
}
static
MaxOp
*
Get
();
};
};
...
@@ -90,32 +98,57 @@ class MinOp : public BinaryOp {
...
@@ -90,32 +98,57 @@ class MinOp : public BinaryOp {
const
char
*
FunctionName
()
const
override
{
const
char
*
FunctionName
()
const
override
{
return
"min"
;
return
"min"
;
}
}
static
MinOp
*
Get
();
};
};
#define DEFINE_
OP_OVERLOAD(OpChar, OpName)
\
#define DEFINE_
BINARY_OP_OVERLOAD(OpChar)
\
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
static const BinaryOp* op = BinaryOp::Get(#OpChar); \
return (*op)(lhs, rhs); \
}
}
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \
#define DEFINE_BINARY_OP_FUNCTION(FuncName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
inline Expr FuncName(Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
static const BinaryOp* op = BinaryOp::Get(#FuncName); \
return (*op)(lhs, rhs); \
}
}
DEFINE_
OP_OVERLOAD
(
+
,
AddOp
);
DEFINE_
BINARY_OP_OVERLOAD
(
+
);
DEFINE_
OP_OVERLOAD
(
-
,
SubOp
);
DEFINE_
BINARY_OP_OVERLOAD
(
-
);
DEFINE_
OP_OVERLOAD
(
*
,
MulOp
);
DEFINE_
BINARY_OP_OVERLOAD
(
*
);
DEFINE_
OP_OVERLOAD
(
/
,
DivOp
);
DEFINE_
BINARY_OP_OVERLOAD
(
/
);
DEFINE_BINARY_OP_FUNCTION
(
max
,
MaxOp
);
DEFINE_BINARY_OP_FUNCTION
(
max
);
DEFINE_BINARY_OP_FUNCTION
(
min
,
MinOp
);
DEFINE_BINARY_OP_FUNCTION
(
min
);
// overload negation
// overload negation
inline
Expr
operator
-
(
Expr
src
)
{
inline
Expr
operator
-
(
Expr
src
)
{
return
src
*
(
-
1
);
return
src
*
(
-
1
);
}
}
// template of op registry
template
<
typename
Op
>
struct
OpReg
{
std
::
string
name
;
std
::
unique_ptr
<
Op
>
op
;
inline
OpReg
&
set
(
Op
*
op
)
{
this
->
op
.
reset
(
op
);
return
*
this
;
}
};
using
UnaryOpReg
=
OpReg
<
UnaryOp
>
;
using
BinaryOpReg
=
OpReg
<
BinaryOp
>
;
#define TVM_REGISTER_BINARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::BinaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
#define TVM_REGISTER_UNARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::UnaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
}
// namespace tvm
}
// namespace tvm
#endif // TVM_OP_H_
#endif // TVM_OP_H_
python/tvm/cpp/expr.py
View file @
8de0a083
from
._ctypes._api
import
NodeBase
,
register_node
from
._ctypes._api
import
NodeBase
,
register_node
from
.function
import
binary_op
from
._function_internal
import
_binary_op
class
Expr
(
NodeBase
):
class
Expr
(
NodeBase
):
pass
def
__add__
(
self
,
other
):
return
binary_op
(
'+'
,
self
,
other
)
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
return
binary_op
(
'-'
,
self
,
other
)
def
__rsub__
(
self
,
other
):
return
binary_op
(
'-'
,
other
,
self
)
def
__mul__
(
self
,
other
):
return
binary_op
(
'*'
,
self
,
other
)
def
__rmul__
(
self
,
other
):
return
binary_op
(
'*'
,
other
,
self
)
def
__div__
(
self
,
other
):
return
binary_op
(
'/'
,
self
,
other
)
def
__rdiv__
(
self
,
other
):
return
binary_op
(
'/'
,
other
,
self
)
def
__truediv__
(
self
,
other
):
return
self
.
__div__
(
other
)
def
__rtruediv__
(
self
,
other
):
return
self
.
__rdiv__
(
other
)
def
__neg__
(
self
):
return
self
.
__mul__
(
-
1
)
@register_node
(
"VarNode"
)
@register_node
(
"VarNode"
)
class
Var
(
Expr
):
class
Var
(
Expr
):
...
...
python/tvm/cpp/function.py
View file @
8de0a083
from
__future__
import
absolute_import
as
_abs
from
numbers
import
Number
as
_Number
from
._ctypes._api
import
_init_function_module
from
._ctypes._api
import
_init_function_module
import
_function_internal
from
.
import
_function_internal
int32
=
1
int32
=
1
float32
=
2
float32
=
2
...
@@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32):
...
@@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32):
return
_function_internal
.
_Var
(
name
,
dtype
)
return
_function_internal
.
_Var
(
name
,
dtype
)
def
_symbol
(
value
):
"""Convert a value to expression."""
if
isinstance
(
value
,
_Number
):
return
constant
(
value
)
else
:
return
value
def
binary_op
(
op
,
lhs
,
rhs
):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return
_function_internal
.
_binary_op
(
op
,
_symbol
(
lhs
),
_symbol
(
rhs
))
def
max
(
lhs
,
rhs
):
"""Max of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return
binary_op
(
"max"
,
lhs
,
rhs
)
def
min
(
lhs
,
rhs
):
"""Min of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return
binary_op
(
"max"
,
lhs
,
rhs
)
_init_function_module
(
"tvm.cpp"
)
_init_function_module
(
"tvm.cpp"
)
src/c_api/c_api_function.cc
View file @
8de0a083
...
@@ -16,6 +16,7 @@ namespace tvm {
...
@@ -16,6 +16,7 @@ namespace tvm {
using
ArgStack
=
const
std
::
vector
<
APIVariantValue
>
;
using
ArgStack
=
const
std
::
vector
<
APIVariantValue
>
;
using
RetValue
=
APIVariantValue
;
using
RetValue
=
APIVariantValue
;
// expression logic x
TVM_REGISTER_API
(
_Var
)
TVM_REGISTER_API
(
_Var
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Var
(
args
.
at
(
0
),
*
ret
=
Var
(
args
.
at
(
0
),
...
@@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var)
...
@@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var)
.
add_argument
(
"name"
,
"str"
,
"name of the var"
)
.
add_argument
(
"name"
,
"str"
,
"name of the var"
)
.
add_argument
(
"dtype"
,
"int"
,
"data type of var"
);
.
add_argument
(
"dtype"
,
"int"
,
"data type of var"
);
TVM_REGISTER_API
(
constant
)
TVM_REGISTER_API
(
max
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
max
(
args
.
at
(
0
),
args
.
at
(
1
));
if
(
args
.
at
(
0
).
type_id
==
kLong
)
{
*
ret
=
IntConstant
(
args
.
at
(
0
));
}
else
if
(
args
.
at
(
0
).
type_id
==
kDouble
)
{
*
ret
=
FloatConstant
(
args
.
at
(
0
));
}
else
{
LOG
(
FATAL
)
<<
"only accept int or float"
;
}
})
})
.
add_argument
(
"lhs"
,
"Expr"
,
"left operand"
)
.
add_argument
(
"src"
,
"Number"
,
"source number"
);
.
add_argument
(
"rhs"
,
"Expr"
,
"right operand"
);
TVM_REGISTER_API
(
min
)
TVM_REGISTER_API
(
_binary_op
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
min
(
args
.
at
(
0
),
args
.
at
(
1
));
CHECK
(
args
.
at
(
0
).
type_id
==
kStr
);
*
ret
=
(
*
BinaryOp
::
Get
(
args
.
at
(
0
).
str
.
c_str
()))(
args
.
at
(
1
),
args
.
at
(
2
));
})
})
.
add_argument
(
"op"
,
"str"
,
"operator"
)
.
add_argument
(
"lhs"
,
"Expr"
,
"left operand"
)
.
add_argument
(
"lhs"
,
"Expr"
,
"left operand"
)
.
add_argument
(
"rhs"
,
"Expr"
,
"right operand"
);
.
add_argument
(
"rhs"
,
"Expr"
,
"right operand"
);
// transformations
TVM_REGISTER_API
(
format_str
)
TVM_REGISTER_API
(
format_str
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
...
...
src/expr/op.cc
View file @
8de0a083
...
@@ -5,6 +5,12 @@
...
@@ -5,6 +5,12 @@
#include <tvm/op.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <tvm/expr_node.h>
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
BinaryOpReg
);
DMLC_REGISTRY_ENABLE
(
::
tvm
::
UnaryOpReg
);
}
// namespace dmlc
namespace
tvm
{
namespace
tvm
{
Expr
BinaryOp
::
operator
()(
Expr
lhs
,
Expr
rhs
)
const
{
Expr
BinaryOp
::
operator
()(
Expr
lhs
,
Expr
rhs
)
const
{
...
@@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
...
@@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
return
Expr
(
std
::
move
(
nptr
));
return
Expr
(
std
::
move
(
nptr
));
}
}
#define DEFINE_SINGLETON_GET(TypeName) \
const
BinaryOp
*
BinaryOp
::
Get
(
const
char
*
name
)
{
TypeName* TypeName::Get() { \
const
auto
*
op
=
dmlc
::
Registry
<
BinaryOpReg
>::
Find
(
name
);
static TypeName inst; \
CHECK
(
op
!=
nullptr
)
<<
"cannot find "
<<
name
;
return &inst; \
return
op
->
op
.
get
();
}
}
DEFINE_SINGLETON_GET
(
AddOp
);
TVM_REGISTER_BINARY_OP
(
+
,
AddOp
);
DEFINE_SINGLETON_GET
(
SubOp
);
TVM_REGISTER_BINARY_OP
(
-
,
SubOp
);
DEFINE_SINGLETON_GET
(
MulOp
);
TVM_REGISTER_BINARY_OP
(
*
,
MulOp
);
DEFINE_SINGLETON_GET
(
DivOp
);
TVM_REGISTER_BINARY_OP
(
/
,
DivOp
);
DEFINE_SINGLETON_GET
(
MaxOp
);
TVM_REGISTER_BINARY_OP
(
max
,
MaxOp
);
DEFINE_SINGLETON_GET
(
MinOp
);
TVM_REGISTER_BINARY_OP
(
min
,
MinOp
);
}
// namespace tvm
}
// namespace tvm
tests/python/test_cpp.py
View file @
8de0a083
...
@@ -3,8 +3,8 @@ from tvm import cpp as tvm
...
@@ -3,8 +3,8 @@ from tvm import cpp as tvm
def
test_basic
():
def
test_basic
():
a
=
tvm
.
Var
(
'a'
)
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
b
=
tvm
.
Var
(
'b'
)
z
=
tvm
.
max
(
a
,
b
)
c
=
a
+
b
assert
tvm
.
format_str
(
z
)
==
'max(
%
s,
%
s)'
%
(
a
.
name
,
b
.
name
)
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_basic
()
test_basic
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment