Commit 8de0a083 by tqchen

[OP] enable binary op

parent 1a7fb9f9
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef TVM_OP_H_ #ifndef TVM_OP_H_
#define TVM_OP_H_ #define TVM_OP_H_
#include <dmlc/registry.h>
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
...@@ -14,6 +15,8 @@ namespace tvm { ...@@ -14,6 +15,8 @@ namespace tvm {
/*! \brief binary operator */ /*! \brief binary operator */
class BinaryOp { class BinaryOp {
public: public:
// virtual destructor
virtual ~BinaryOp() {}
/*! \return the function name to be called in binary op */ /*! \return the function name to be called in binary op */
virtual const char* FunctionName() const = 0; virtual const char* FunctionName() const = 0;
/*! /*!
...@@ -23,6 +26,11 @@ class BinaryOp { ...@@ -23,6 +26,11 @@ class BinaryOp {
* \return the result expr * \return the result expr
*/ */
Expr operator()(Expr lhs, Expr rhs) const; Expr operator()(Expr lhs, Expr rhs) const;
/*!
* \brief get binary op by name
* \param name name of operator
*/
static const BinaryOp* Get(const char* name);
}; };
...@@ -37,6 +45,11 @@ class UnaryOp { ...@@ -37,6 +45,11 @@ class UnaryOp {
* \return the result expr * \return the result expr
*/ */
Expr operator()(Expr src) const; Expr operator()(Expr src) const;
/*!
* \brief get unary op by name
* \param name name of operator
*/
static const UnaryOp* Get(const char* name);
}; };
...@@ -45,7 +58,6 @@ class AddOp : public BinaryOp { ...@@ -45,7 +58,6 @@ class AddOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "+"; return "+";
} }
static AddOp* Get();
}; };
...@@ -54,7 +66,6 @@ class SubOp : public BinaryOp { ...@@ -54,7 +66,6 @@ class SubOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "-"; return "-";
} }
static SubOp* Get();
}; };
...@@ -63,7 +74,6 @@ class MulOp : public BinaryOp { ...@@ -63,7 +74,6 @@ class MulOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "*"; return "*";
} }
static MulOp* Get();
}; };
...@@ -72,7 +82,6 @@ class DivOp : public BinaryOp { ...@@ -72,7 +82,6 @@ class DivOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "/"; return "/";
} }
static DivOp* Get();
}; };
...@@ -81,7 +90,6 @@ class MaxOp : public BinaryOp { ...@@ -81,7 +90,6 @@ class MaxOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "max"; return "max";
} }
static MaxOp* Get();
}; };
...@@ -90,32 +98,57 @@ class MinOp : public BinaryOp { ...@@ -90,32 +98,57 @@ class MinOp : public BinaryOp {
const char* FunctionName() const override { const char* FunctionName() const override {
return "min"; return "min";
} }
static MinOp* Get();
}; };
#define DEFINE_OP_OVERLOAD(OpChar, OpName) \ #define DEFINE_BINARY_OP_OVERLOAD(OpChar) \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \ inline Expr operator OpChar (Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \ static const BinaryOp* op = BinaryOp::Get(#OpChar); \
return (*op)(lhs, rhs); \
} }
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \ #define DEFINE_BINARY_OP_FUNCTION(FuncName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \ inline Expr FuncName(Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \ static const BinaryOp* op = BinaryOp::Get(#FuncName); \
return (*op)(lhs, rhs); \
} }
DEFINE_OP_OVERLOAD(+, AddOp); DEFINE_BINARY_OP_OVERLOAD(+);
DEFINE_OP_OVERLOAD(-, SubOp); DEFINE_BINARY_OP_OVERLOAD(-);
DEFINE_OP_OVERLOAD(*, MulOp); DEFINE_BINARY_OP_OVERLOAD(*);
DEFINE_OP_OVERLOAD(/, DivOp); DEFINE_BINARY_OP_OVERLOAD(/);
DEFINE_BINARY_OP_FUNCTION(max, MaxOp); DEFINE_BINARY_OP_FUNCTION(max);
DEFINE_BINARY_OP_FUNCTION(min, MinOp); DEFINE_BINARY_OP_FUNCTION(min);
// overload negation // overload negation
inline Expr operator-(Expr src) { inline Expr operator-(Expr src) {
return src * (-1); return src * (-1);
} }
// template of op registry
template<typename Op>
struct OpReg {
std::string name;
std::unique_ptr<Op> op;
inline OpReg& set(Op* op) {
this->op.reset(op);
return *this;
}
};
using UnaryOpReg = OpReg<UnaryOp>;
using BinaryOpReg = OpReg<BinaryOp>;
#define TVM_REGISTER_BINARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::BinaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
#define TVM_REGISTER_UNARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::UnaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
} // namespace tvm } // namespace tvm
#endif // TVM_OP_H_ #endif // TVM_OP_H_
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from .function import binary_op
from ._function_internal import _binary_op
class Expr(NodeBase): class Expr(NodeBase):
pass def __add__(self, other):
return binary_op('+', self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return binary_op('-', self, other)
def __rsub__(self, other):
return binary_op('-', other, self)
def __mul__(self, other):
return binary_op('*', self, other)
def __rmul__(self, other):
return binary_op('*', other, self)
def __div__(self, other):
return binary_op('/', self, other)
def __rdiv__(self, other):
return binary_op('/', other, self)
def __truediv__(self, other):
return self.__div__(other)
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __neg__(self):
return self.__mul__(-1)
@register_node("VarNode") @register_node("VarNode")
class Var(Expr): class Var(Expr):
......
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from ._ctypes._api import _init_function_module from ._ctypes._api import _init_function_module
import _function_internal from .import _function_internal
int32 = 1 int32 = 1
float32 = 2 float32 = 2
...@@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32): ...@@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype) return _function_internal._Var(name, dtype)
def _symbol(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return constant(value)
else:
return value
def binary_op(op, lhs, rhs):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return _function_internal._binary_op(op, _symbol(lhs), _symbol(rhs))
def max(lhs, rhs):
"""Max of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return binary_op("max", lhs, rhs)
def min(lhs, rhs):
"""Min of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return binary_op("max", lhs, rhs)
_init_function_module("tvm.cpp") _init_function_module("tvm.cpp")
...@@ -16,6 +16,7 @@ namespace tvm { ...@@ -16,6 +16,7 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
// expression logic x
TVM_REGISTER_API(_Var) TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0), *ret = Var(args.at(0),
...@@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var) ...@@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var)
.add_argument("name", "str", "name of the var") .add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var"); .add_argument("dtype", "int", "data type of var");
TVM_REGISTER_API(constant)
TVM_REGISTER_API(max)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = max(args.at(0), args.at(1)); if (args.at(0).type_id == kLong) {
*ret = IntConstant(args.at(0));
} else if (args.at(0).type_id == kDouble) {
*ret = FloatConstant(args.at(0));
} else {
LOG(FATAL) << "only accept int or float";
}
}) })
.add_argument("lhs", "Expr", "left operand") .add_argument("src", "Number", "source number");
.add_argument("rhs", "Expr", "right operand");
TVM_REGISTER_API(min) TVM_REGISTER_API(_binary_op)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = min(args.at(0), args.at(1)); CHECK(args.at(0).type_id == kStr);
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
}) })
.add_argument("op", "str", "operator")
.add_argument("lhs", "Expr", "left operand") .add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand"); .add_argument("rhs", "Expr", "right operand");
// transformations
TVM_REGISTER_API(format_str) TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
std::ostringstream os; std::ostringstream os;
......
...@@ -5,6 +5,12 @@ ...@@ -5,6 +5,12 @@
#include <tvm/op.h> #include <tvm/op.h>
#include <tvm/expr_node.h> #include <tvm/expr_node.h>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::BinaryOpReg);
DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
} // namespace dmlc
namespace tvm { namespace tvm {
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const { Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
...@@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const { ...@@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
return Expr(std::move(nptr)); return Expr(std::move(nptr));
} }
#define DEFINE_SINGLETON_GET(TypeName) \ const BinaryOp* BinaryOp::Get(const char* name) {
TypeName* TypeName::Get() { \ const auto* op = dmlc::Registry<BinaryOpReg>::Find(name);
static TypeName inst; \ CHECK(op != nullptr) << "cannot find " << name;
return &inst; \ return op->op.get();
} }
DEFINE_SINGLETON_GET(AddOp); TVM_REGISTER_BINARY_OP(+, AddOp);
DEFINE_SINGLETON_GET(SubOp); TVM_REGISTER_BINARY_OP(-, SubOp);
DEFINE_SINGLETON_GET(MulOp); TVM_REGISTER_BINARY_OP(*, MulOp);
DEFINE_SINGLETON_GET(DivOp); TVM_REGISTER_BINARY_OP(/, DivOp);
DEFINE_SINGLETON_GET(MaxOp); TVM_REGISTER_BINARY_OP(max, MaxOp);
DEFINE_SINGLETON_GET(MinOp); TVM_REGISTER_BINARY_OP(min, MinOp);
} // namespace tvm } // namespace tvm
...@@ -3,8 +3,8 @@ from tvm import cpp as tvm ...@@ -3,8 +3,8 @@ from tvm import cpp as tvm
def test_basic(): def test_basic():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
z = tvm.max(a, b) c = a + b
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name) assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
if __name__ == "__main__": if __name__ == "__main__":
test_basic() test_basic()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment