Commit ed99ddc7 by tqchen

Enable reduction in front-end

parent 7c550f11
Subproject commit f0deabe56bc20e60899e44b432d4a628a90161f3
Subproject commit 2b3ea8f5207152340014fd0a1ab12816ac48c326
@@ -46,6 +46,7 @@ class RDomain : public NodeRef {
 public:
  /*! \brief constructor */
  RDomain() {}
  explicit RDomain(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief constructor by domain
   * \param domain The domain of reduction.
......
@@ -48,4 +48,4 @@ struct Reduce : public ExprNode<Reduce> {
} // namespace ir
} // namespace tvm
#endif // TVM_IR_NODE_H_
#endif // TVM_IR_H_
@@ -61,6 +61,9 @@ class NodeBase(object):
        """
        self.handle = handle

    def __repr__(self):
        return _function_internal.format_str(self)

    def __del__(self):
        check_call(_LIB.TVMNodeFree(self.handle))
......
@@ -22,3 +22,8 @@ class Range(NodeBase):
    def __repr__(self):
        return ('Range(min=' + str(self.min) +
                ', extent=' + str(self.extent) + ')')


@register_node
class RDomain(NodeBase):
    pass
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Expr(NodeBase):
    def __repr__(self):
        return _func.format_str(self)

    def __add__(self, other):
        return _make.Add(self, other)
@@ -52,9 +48,14 @@ class CmpExpr(Expr):

class LogicalExpr(Expr):
    pass


@register_node("Variable")
class Var(Expr):
    pass


@register_node
class Reduce(Expr):
    pass


@register_node
......
from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral
from ._ctypes._api import _init_function_module
from .import _function_internal
from .import make as _make
from . import _function_internal
from . import make as _make
from . import expr as _expr
from . import collections as _collections
int32 = "int32"
float32 = "float32"
@@ -76,4 +78,98 @@ def Tensor(shape, fcompute=None, dtype=None, name="TensorObj"):
        shape, name, dtype, None, None)


def RDomain(dom):
    """Create a reduction domain given a domain specification.

    Parameters
    ----------
    dom : list of Range or list of pairs
        The reduction domain.

    Returns
    -------
    rdom : RDomain
        The result rdomain.
    """
    if not isinstance(dom, (list, tuple)):
        dom = [dom]
    elif not isinstance(dom[0], (list, tuple)):
        dom = [dom]
    dnorm = []
    for x in dom:
        if isinstance(x, (list, tuple)):
            if len(x) != 2:
                raise ValueError("each domain item must be a Range or a pair of length 2")
            dnorm.append(Range(x[0], x[1]))
        else:
            dnorm.append(x)
    dnorm = convert(dnorm)
    return _function_internal._RDomain(dnorm)
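For illustration, a minimal sketch of how this helper might be called, assuming the tvm.Range and tvm.Var helpers used elsewhere in this commit (variable names are placeholders):

import tvm

n = tvm.Var('n')
rd1 = tvm.RDomain(tvm.Range(n))        # a single Range is wrapped into a list
rd2 = tvm.RDomain([(0, n), (0, 16)])   # pairs are normalized via Range(x[0], x[1])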
def sum(expr, rdom):
    """Create a sum expression over rdom.

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain
        The reduction domain.
    """
    assert isinstance(expr, _expr.Expr)
    assert isinstance(rdom, _collections.RDomain)
    x = _make.Reduce("Add", expr, rdom)
    return x


def min(expr, rdom):
    """Create a min expression over rdom.

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain
        The reduction domain.
    """
    assert isinstance(expr, _expr.Expr)
    assert isinstance(rdom, _collections.RDomain)
    x = _make.Reduce("Min", expr, rdom)
    return x


def max(expr, rdom):
    """Create a max expression over rdom.

    Parameters
    ----------
    expr : Expr
        The source expression.

    rdom : RDomain
        The reduction domain.
    """
    assert isinstance(expr, _expr.Expr)
    assert isinstance(rdom, _collections.RDomain)
    x = _make.Reduce("Max", expr, rdom)
    return x

_init_function_module("tvm")
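A minimal usage sketch of the reduction helpers defined above, assuming the Var/Tensor front-end from this commit and the rd.index accessor exercised in the test further down (variable names are placeholders):

import tvm

m = tvm.Var('m')
l = tvm.Var('l')
i = tvm.Var('i')
A = tvm.Tensor((m, l), name='A')
rd = tvm.RDomain(tvm.Range(A.shape[1]))

# Each helper builds a Reduce node with the matching op string.
s  = tvm.sum(A(i, rd.index[0]), rdom=rd)   # Reduce("Add", ...)
lo = tvm.min(A(i, rd.index[0]), rdom=rd)   # Reduce("Min", ...)
hi = tvm.max(A(i, rd.index[0]), rdom=rd)   # Reduce("Max", ...)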
@@ -18,7 +18,7 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(format_str)
TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
......
@@ -4,11 +4,13 @@
* \file c_api_ir.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
using namespace tvm::ir;
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
@@ -29,6 +31,12 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
TVM_REGISTER_API(_make_Reduce)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Reduce::make(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
......
@@ -17,11 +17,22 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
namespace Halide {
namespace Internal {
using tvm::ir::Reduce;
template<>
void ExprNode<tvm::ir::Reduce>::accept(IRVisitor *v) const {
void ExprNode<Reduce>::accept(IRVisitor *v) const {
  LOG(FATAL) << "Reduce does not work with IRVisitor yet";
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce("
<< op->op
<< ", ";
p->print(op->source);
p->stream << ", rdom=" << op->rdom << ")";
});
} // namespace Internal
} // namespace Halide
@@ -31,7 +42,7 @@ namespace ir {
// reduce
TVM_REGISTER_NODE_TYPE(Reduce);
Expr make(std::string op, Expr source, RDomain rdom) {
Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
n->type = source.type();
......
@@ -11,6 +11,17 @@ def test_tensor():
    assert(tuple(T.shape) == (m, n, l))
    assert(A.source is None)


def test_tensor_reduce():
    m = tvm.Var('m')
    n = tvm.Var('n')
    l = tvm.Var('l')
    A = tvm.Tensor((m, l), name='A')
    B = tvm.Tensor((n, l), name='B')
    T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
    rd = tvm.RDomain(tvm.Range(A.shape[1]))
    C = tvm.Tensor((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd))
    print(tvm.format_str(C.source))


if __name__ == "__main__":
    test_tensor()
    test_tensor_reduce()
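The same pattern should extend to the other reductions registered above; a sketch, not part of the original test, assuming the same front-end API and that a 1-D Tensor takes a single-argument compute lambda:

def test_tensor_reduce_minmax():
    # Mirrors test_tensor_reduce, but exercises tvm.min and tvm.max.
    m = tvm.Var('m')
    l = tvm.Var('l')
    A = tvm.Tensor((m, l), name='A')
    rd = tvm.RDomain(tvm.Range(A.shape[1]))
    Cmin = tvm.Tensor((m,), lambda i: tvm.min(A(i, rd.index[0]), rdom=rd))
    Cmax = tvm.Tensor((m,), lambda i: tvm.max(A(i, rd.index[0]), rdom=rd))
    print(tvm.format_str(Cmin.source))
    print(tvm.format_str(Cmax.source))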