Commit 8de0a083 by tqchen

[OP] enable binary op

parent 1a7fb9f9
......@@ -6,6 +6,7 @@
#ifndef TVM_OP_H_
#define TVM_OP_H_
#include <dmlc/registry.h>
#include <string>
#include "./expr.h"
......@@ -14,6 +15,8 @@ namespace tvm {
/*! \brief binary operator */
class BinaryOp {
public:
// virtual destructor
virtual ~BinaryOp() {}
/*! \return the function name to be called in binary op */
virtual const char* FunctionName() const = 0;
/*!
......@@ -23,6 +26,11 @@ class BinaryOp {
* \return the result expr
*/
Expr operator()(Expr lhs, Expr rhs) const;
/*!
* \brief get binary op by name
* \param name name of operator
*/
static const BinaryOp* Get(const char* name);
};
......@@ -37,6 +45,11 @@ class UnaryOp {
* \return the result expr
*/
Expr operator()(Expr src) const;
/*!
* \brief get unary op by name
* \param name name of operator
*/
static const UnaryOp* Get(const char* name);
};
......@@ -45,7 +58,6 @@ class AddOp : public BinaryOp {
const char* FunctionName() const override {
return "+";
}
static AddOp* Get();
};
......@@ -54,7 +66,6 @@ class SubOp : public BinaryOp {
const char* FunctionName() const override {
return "-";
}
static SubOp* Get();
};
......@@ -63,7 +74,6 @@ class MulOp : public BinaryOp {
const char* FunctionName() const override {
return "*";
}
static MulOp* Get();
};
......@@ -72,7 +82,6 @@ class DivOp : public BinaryOp {
const char* FunctionName() const override {
return "/";
}
static DivOp* Get();
};
......@@ -81,7 +90,6 @@ class MaxOp : public BinaryOp {
const char* FunctionName() const override {
return "max";
}
static MaxOp* Get();
};
......@@ -90,32 +98,57 @@ class MinOp : public BinaryOp {
const char* FunctionName() const override {
return "min";
}
static MinOp* Get();
};
#define DEFINE_OP_OVERLOAD(OpChar, OpName) \
#define DEFINE_BINARY_OP_OVERLOAD(OpChar) \
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
static const BinaryOp* op = BinaryOp::Get(#OpChar); \
return (*op)(lhs, rhs); \
}
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
return (*OpName::Get())(lhs, rhs); \
#define DEFINE_BINARY_OP_FUNCTION(FuncName) \
inline Expr FuncName(Expr lhs, Expr rhs) { \
static const BinaryOp* op = BinaryOp::Get(#FuncName); \
return (*op)(lhs, rhs); \
}
DEFINE_OP_OVERLOAD(+, AddOp);
DEFINE_OP_OVERLOAD(-, SubOp);
DEFINE_OP_OVERLOAD(*, MulOp);
DEFINE_OP_OVERLOAD(/, DivOp);
DEFINE_BINARY_OP_OVERLOAD(+);
DEFINE_BINARY_OP_OVERLOAD(-);
DEFINE_BINARY_OP_OVERLOAD(*);
DEFINE_BINARY_OP_OVERLOAD(/);
DEFINE_BINARY_OP_FUNCTION(max, MaxOp);
DEFINE_BINARY_OP_FUNCTION(min, MinOp);
DEFINE_BINARY_OP_FUNCTION(max);
DEFINE_BINARY_OP_FUNCTION(min);
// overload negation
inline Expr operator-(Expr src) {
return src * (-1);
}
// template of op registry
template<typename Op>
struct OpReg {
std::string name;
std::unique_ptr<Op> op;
inline OpReg& set(Op* op) {
this->op.reset(op);
return *this;
}
};
using UnaryOpReg = OpReg<UnaryOp>;
using BinaryOpReg = OpReg<BinaryOp>;
#define TVM_REGISTER_BINARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::BinaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
#define TVM_REGISTER_UNARY_OP(FunctionName, TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
::dmlc::Registry<::tvm::UnaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
.set(new TypeName())
} // namespace tvm
#endif // TVM_OP_H_
from ._ctypes._api import NodeBase, register_node
from .function import binary_op
from ._function_internal import _binary_op
class Expr(NodeBase):
pass
def __add__(self, other):
return binary_op('+', self, other)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return binary_op('-', self, other)
def __rsub__(self, other):
return binary_op('-', other, self)
def __mul__(self, other):
return binary_op('*', self, other)
def __rmul__(self, other):
return binary_op('*', other, self)
def __div__(self, other):
return binary_op('/', self, other)
def __rdiv__(self, other):
return binary_op('/', other, self)
def __truediv__(self, other):
return self.__div__(other)
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __neg__(self):
return self.__mul__(-1)
@register_node("VarNode")
class Var(Expr):
......
from __future__ import absolute_import as _abs
from numbers import Number as _Number
from ._ctypes._api import _init_function_module
import _function_internal
from .import _function_internal
int32 = 1
float32 = 2
......@@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)
def _symbol(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return constant(value)
else:
return value
def binary_op(op, lhs, rhs):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return _function_internal._binary_op(op, _symbol(lhs), _symbol(rhs))
def max(lhs, rhs):
"""Max of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return binary_op("max", lhs, rhs)
def min(lhs, rhs):
"""Min of two expressions
Parameters
----------
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return binary_op("max", lhs, rhs)
_init_function_module("tvm.cpp")
......@@ -16,6 +16,7 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// expression logic x
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0),
......@@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var)
.add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var");
TVM_REGISTER_API(max)
TVM_REGISTER_API(constant)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = max(args.at(0), args.at(1));
if (args.at(0).type_id == kLong) {
*ret = IntConstant(args.at(0));
} else if (args.at(0).type_id == kDouble) {
*ret = FloatConstant(args.at(0));
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand");
.add_argument("src", "Number", "source number");
TVM_REGISTER_API(min)
TVM_REGISTER_API(_binary_op)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = min(args.at(0), args.at(1));
CHECK(args.at(0).type_id == kStr);
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
})
.add_argument("op", "str", "operator")
.add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand");
// transformations
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
std::ostringstream os;
......
......@@ -5,6 +5,12 @@
#include <tvm/op.h>
#include <tvm/expr_node.h>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::BinaryOpReg);
DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
} // namespace dmlc
namespace tvm {
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
......@@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
return Expr(std::move(nptr));
}
#define DEFINE_SINGLETON_GET(TypeName) \
TypeName* TypeName::Get() { \
static TypeName inst; \
return &inst; \
}
DEFINE_SINGLETON_GET(AddOp);
DEFINE_SINGLETON_GET(SubOp);
DEFINE_SINGLETON_GET(MulOp);
DEFINE_SINGLETON_GET(DivOp);
DEFINE_SINGLETON_GET(MaxOp);
DEFINE_SINGLETON_GET(MinOp);
const BinaryOp* BinaryOp::Get(const char* name) {
const auto* op = dmlc::Registry<BinaryOpReg>::Find(name);
CHECK(op != nullptr) << "cannot find " << name;
return op->op.get();
}
TVM_REGISTER_BINARY_OP(+, AddOp);
TVM_REGISTER_BINARY_OP(-, SubOp);
TVM_REGISTER_BINARY_OP(*, MulOp);
TVM_REGISTER_BINARY_OP(/, DivOp);
TVM_REGISTER_BINARY_OP(max, MaxOp);
TVM_REGISTER_BINARY_OP(min, MinOp);
} // namespace tvm
......@@ -3,8 +3,8 @@ from tvm import cpp as tvm
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
z = tvm.max(a, b)
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name)
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
if __name__ == "__main__":
test_basic()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment