Commit 5445a936 by tqchen

Refactor to use iterVar

parent 7591714a
Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835 Subproject commit e96ee0f2fb5239021c0facd5398a9a96644bc411
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the domain in AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <ir/Range.h>
#include <memory>
#include "./base.h"
#include "./expr.h"
namespace tvm {
/*! \brief container class of reduction domain */
class RDomainNode;
class IterDomainNode;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
explicit RDomain(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
// low level constructor
static RDomain make(Array<Var> index, Domain domain);
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional domain.
*/
class IterVarNode : public Node {
/*! \brief The */
Var var;
/*! \brief the domain of iteration */
Range dom;
/*! \brief additional tag on the iteration variable */
std::string tag;
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) {
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
static constexpr const char* _type_key = "RDomain";
TVM_DECLARE_NODE_TYPE_INFO(RDomainNode);
};
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
// overload print function
inline std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r->domain << ")";
return os;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file expr.h * \file expr.h
* \brief Defines the expressions in AST. * \brief The Expr and related elements in DataFlow construction.
*/ */
#ifndef TVM_EXPR_H_ #ifndef TVM_EXPR_H_
#define TVM_EXPR_H_ #define TVM_EXPR_H_
#include <ir/Expr.h> #include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include <string> #include <string>
#include <algorithm>
#include "./base.h" #include "./base.h"
namespace tvm { namespace tvm {
...@@ -19,20 +21,14 @@ using Halide::Int; ...@@ -19,20 +21,14 @@ using Halide::Int;
using Halide::UInt; using Halide::UInt;
using Halide::Handle; using Halide::Handle;
// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::VarExpr; using Halide::VarExpr;
using Halide::IR::FunctionRef; using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode; using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
public: public:
explicit Var(const std::string& name_hint = "v", explicit Var(const std::string& name_hint = "v",
...@@ -41,5 +37,134 @@ class Var : public Halide::VarExpr { ...@@ -41,5 +37,134 @@ class Var : public Halide::VarExpr {
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
}; };
/*! \brief container class of iteration variable. */
class IterVarNode;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
};
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class IterVar : public NodeRef {
public:
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construction of iteration variable.
* \param dom The iteration domain.
* \param var_name The name of iteration variable.
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
*/
explicit IterVar(Range dom, std::string var_name = "i", std::string thread_tag = "");
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarNode* operator->() const;
/*!
* \return the corresponding var in the IterVar.
*/
inline operator Expr() const;
/*! \brief specify container node */
using ContainerType = IterVarNode;
};
using Domain = Array<Range>;
// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;
/*!
* \brief sum of of source expression over rdom
* \param source The source expression.
*/
Expr sum(Expr source, Array<IterVar> rdom);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr max(Expr source, Array<IterVar> rdom);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr min(Expr source, Array<IterVar> rdom);
// print functions for expr
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
// definition of Node.
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class IterVarNode : public Node {
public:
/*! \brief The looping variable */
Var var;
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range dom;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
*/
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("dom", &dom);
v->Visit("thread_tag", &thread_tag);
}
static IterVar make(Var var, Range dom, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
};
// inline implementations
inline const IterVarNode* IterVar::operator->() const {
return static_cast<const IterVarNode*>(node_.get());
}
inline IterVar::operator Expr() const {
return (*this)->var;
}
} // namespace tvm } // namespace tvm
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <type_traits> #include <type_traits>
#include <string> #include <string>
#include "./base.h" #include "./base.h"
#include "./domain.h" #include "./expr.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -30,11 +30,11 @@ struct Reduce : public ExprNode<Reduce> { ...@@ -30,11 +30,11 @@ struct Reduce : public ExprNode<Reduce> {
std::string op; std::string op;
/*! \brief The source operand */ /*! \brief The source operand */
Expr source; Expr source;
/*! \brief The reduction domain */ /*! \brief The reduction domains */
RDomain rdom; Array<IterVar> rdom;
/*! \brief construct expr from name and rdom */ /*! \brief construct expr from op and rdom */
static Expr make(std::string name, Expr src, RDomain rdom); static Expr make(std::string op, Expr src, Array<IterVar> rdom);
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type); v->Visit("dtype", &type);
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./domain.h"
#include "./tensor.h" #include "./tensor.h"
namespace tvm { namespace tvm {
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./domain.h"
namespace tvm { namespace tvm {
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./domain.h"
namespace tvm { namespace tvm {
...@@ -66,8 +65,8 @@ class Tensor : public FunctionRef { ...@@ -66,8 +65,8 @@ class Tensor : public FunctionRef {
* \return the result expression representing tensor read. * \return the result expression representing tensor read.
*/ */
Expr operator()(Array<Expr> indices) const; Expr operator()(Array<Expr> indices) const;
// overload print function /*! \brief specify container node */
friend std::ostream& operator<<(std::ostream &os, const Tensor& t); using ContainerType = TensorNode;
}; };
/*! \brief Operation that produces tensors */ /*! \brief Operation that produces tensors */
...@@ -87,6 +86,8 @@ class Operation : public NodeRef { ...@@ -87,6 +86,8 @@ class Operation : public NodeRef {
* \return The i-th output. * \return The i-th output.
*/ */
Tensor output(size_t i) const; Tensor output(size_t i) const;
/*! \brief specify container node */
using ContainerType = OperationNode;
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
...@@ -162,11 +163,5 @@ inline size_t Tensor::ndim() const { ...@@ -162,11 +163,5 @@ inline size_t Tensor::ndim() const {
return (*this)->shape.size(); return (*this)->shape.size();
} }
inline std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
return os;
}
} // namespace tvm } // namespace tvm
#endif // TVM_TENSOR_H_ #endif // TVM_TENSOR_H_
...@@ -118,6 +118,7 @@ def convert(value): ...@@ -118,6 +118,7 @@ def convert(value):
raise ValueError("don't know how to handle type %s" % type(value)) raise ValueError("don't know how to handle type %s" % type(value))
return value return value
def _push_arg(arg): def _push_arg(arg):
a = ArgVariant() a = ArgVariant()
if arg is None: if arg is None:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal from . import _function_internal
from . import expr as _expr
@register_node @register_node
class Array(NodeBase): class Array(NodeBase):
...@@ -19,11 +20,9 @@ class Array(NodeBase): ...@@ -19,11 +20,9 @@ class Array(NodeBase):
@register_node @register_node
class Range(NodeBase): class Range(NodeBase):
def __repr__(self): pass
return ('Range(min='+ str(self.min) +
', extent=' + str(self.extent) + ')')
@register_node @register_node
class RDomain(NodeBase): class IterVar(_expr.ExprCompatible):
pass pass
...@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs ...@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import make as _make from . import make as _make
class Expr(NodeBase): class ExprCompatible(NodeBase):
def __add__(self, other): def __add__(self, other):
return _make.Add(self, other) return _make.Add(self, other)
...@@ -36,6 +36,10 @@ class Expr(NodeBase): ...@@ -36,6 +36,10 @@ class Expr(NodeBase):
def __neg__(self): def __neg__(self):
return self.__mul__(-1) return self.__mul__(-1)
class Expr(ExprCompatible):
pass
class ConstExpr(Expr): class ConstExpr(Expr):
pass pass
......
...@@ -103,33 +103,34 @@ def compute(shape, fcompute, name="TensorCompute"): ...@@ -103,33 +103,34 @@ def compute(shape, fcompute, name="TensorCompute"):
shape, name, body.dtype, op_node, 0) shape, name, body.dtype, op_node, 0)
def RDomain(dom): def IterVar(dom, name='iter', thread_tag=''):
"""Create a reduction domain given domain """Create a iteration variable
Parameters Parameters
---------- ----------
dom : list of Range or list of pairs dom : Range
The reduction domain. The domain of iteration.
name : str
The name of iteration variable.
thread_tag : str
The thread tag of the iteration variable.
Returns Returns
------- -------
rdom : RDomain iter_var : IterVar
The result rdomain The result itervar
""" """
if not isinstance(dom, (list, tuple)): if isinstance(dom, (list, tuple)):
dom = [dom] if len(dom) != 2:
elif not isinstance(dom[0], (list, tuple)): raise ValueError("need to list of ranges")
dom = [dom] dom = Range(dom[0], dom[1])
dnorm = []
for x in dom: if not isinstance(dom, _collections.Range):
if isinstance(x, (list, tuple)): raise ValueError("dom need to be Range")
if len(x) != 2:
raise ValueError("need to list of ranges") return _function_internal._IterVar(dom, name, thread_tag)
dnorm.append(Range(x[0], x[1]))
else:
dnorm.append(x)
dnorm = convert(dnorm)
return _function_internal._RDomain(dnorm)
def sum(expr, rdom): def sum(expr, rdom):
...@@ -143,10 +144,11 @@ def sum(expr, rdom): ...@@ -143,10 +144,11 @@ def sum(expr, rdom):
rdom : RDomain rdom : RDomain
The reduction domainx The reduction domainx
""" """
assert isinstance(rdom, _collections.RDomain) rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom) x = _make.Reduce("Add", expr, rdom)
return x return x
def min(expr, rdom): def min(expr, rdom):
"""Create a min expression over rdom """Create a min expression over rdom
...@@ -158,11 +160,11 @@ def min(expr, rdom): ...@@ -158,11 +160,11 @@ def min(expr, rdom):
rdom : RDomain rdom : RDomain
The reduction domainx The reduction domainx
""" """
assert isinstance(expr, _expr.Expr) rdom = rdom if isinstance(rdom, list) else [rdom]
assert isinstance(rdom, _collections.RDomain)
x = _make.Reduce("Min", expr, rdom) x = _make.Reduce("Min", expr, rdom)
return x return x
def max(expr, rdom): def max(expr, rdom):
"""Create a min expression over rdom """Create a min expression over rdom
...@@ -174,8 +176,7 @@ def max(expr, rdom): ...@@ -174,8 +176,7 @@ def max(expr, rdom):
rdom : RDomain rdom : RDomain
The reduction domainx The reduction domainx
""" """
assert isinstance(expr, _expr.Expr) rdom = rdom if isinstance(rdom, list) else [rdom]
assert isinstance(rdom, _collections.RDomain)
x = _make.Reduce("Max", expr, rdom) x = _make.Reduce("Max", expr, rdom)
return x return x
......
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node, convert
from . import collections as _collections
from . import make as _make from . import make as _make
from . import expr as _expr from . import expr as _expr
...@@ -10,7 +11,18 @@ class Tensor(NodeBase): ...@@ -10,7 +11,18 @@ class Tensor(NodeBase):
ndim = self.ndim ndim = self.ndim
if len(indices) != ndim: if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim) raise ValueError("Need to provide %d index in tensor slice" % ndim)
return _make.Call(self.dtype, self.name, indices, _expr.Call.Halide, self, 0) indices = convert(indices)
args = []
for x in indices:
if isinstance(x, _collections.IterVar):
args.append(x.var)
elif isinstance(x, _expr.Expr):
args.append(x)
else:
raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
@property @property
def ndim(self): def ndim(self):
......
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
* \file c_api_impl.cc * \file c_api_impl.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/domain.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace dmlc { namespace dmlc {
...@@ -22,21 +20,9 @@ TVM_REGISTER_API(_format_str) ...@@ -22,21 +20,9 @@ TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode; using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode; using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os; std::ostringstream os;
auto& sptr = args.at(0).sptr; os << args.at(0).operator NodeRef();
if (dynamic_cast<const TensorNode*>(sptr.get())) {
os << args.at(0).operator Tensor();
} else if (dynamic_cast<const RDomainNode*>(sptr.get())) {
os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
} else {
LOG(FATAL) << "don't know how to print input NodeBaseType";
}
*ret = os.str(); *ret = os.str();
}) })
.add_argument("expr", "Node", "expression to be printed"); .add_argument("expr", "Node", "expression to be printed");
......
...@@ -5,10 +5,8 @@ ...@@ -5,10 +5,8 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/domain.h>
#include <tvm/split.h> #include <tvm/split.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace tvm { namespace tvm {
...@@ -95,11 +93,13 @@ TVM_REGISTER_API(_ComputeOp) ...@@ -95,11 +93,13 @@ TVM_REGISTER_API(_ComputeOp)
args.at(3)); args.at(3));
}); });
TVM_REGISTER_API(_RDomain)
TVM_REGISTER_API(_IterVar)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = RDomain(args.at(0).operator Domain()); *ret = IterVar(args.at(0), args.at(1), args.at(2));
}); });
TVM_REGISTER_API(_DimSplit) TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = DimSplitNode::make(args.at(0), args.at(1)); *ret = DimSplitNode::make(args.at(0), args.at(1));
......
...@@ -125,7 +125,13 @@ class APIVariantValue { ...@@ -125,7 +125,13 @@ class APIVariantValue {
return Expr(static_cast<float>(operator double())); return Expr(static_cast<float>(operator double()));
} }
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
return Expr(sptr); if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
} }
inline operator double() const { inline operator double() const {
CHECK_EQ(type_id, kDouble); CHECK_EQ(type_id, kDouble);
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/base.h>
#include <tvm/domain.h>
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
RDomain RDomain::make(Array<Var> index, Domain domain) {
return RDomain(std::make_shared<RDomainNode>(index, domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IRPrinter.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag)
: IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {}
IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
n->var = var;
n->dom = dom;
n->thread_tag = thread_tag;
return IterVar(n);
}
Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Add", source, rdom);
}
Expr max(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Max", source, rdom);
}
Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Min", source, rdom);
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
IRPrinter(os).print(n);
return os;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", ";
}
p->stream << op->dom;
if (op->thread_tag.length() != 0) {
p->stream << ", " << op->thread_tag;
}
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Halide::IR::RangeNode>([](const Halide::IR::RangeNode *op, IRPrinter *p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file ir_node.cc * \file ir.cc
*/ */
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
...@@ -9,11 +9,6 @@ ...@@ -9,11 +9,6 @@
#include <ir/IRPrinter.h> #include <ir/IRPrinter.h>
#include <memory> #include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace Halide { namespace Halide {
namespace Internal { namespace Internal {
...@@ -53,9 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -53,9 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm { namespace tvm {
namespace ir { namespace ir {
Expr Reduce::make(std::string op, Expr source, RDomain rdom) { Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
auto n = std::make_shared<Reduce>(); auto n = std::make_shared<Reduce>();
CHECK(source.defined()); CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
}
n->type = source.type(); n->type = source.type();
n->source = source; n->source = source;
n->op = op; n->op = op;
......
...@@ -41,6 +41,12 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -41,6 +41,12 @@ Tensor TensorNode::make(Array<Expr> shape,
return Tensor(n); return Tensor(n);
} }
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
});
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm } // namespace tvm
...@@ -42,27 +42,29 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { ...@@ -42,27 +42,29 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
} }
} }
inline RDomain MutateRDom(RDomain rdom, IRMutator *m) { inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
std::vector<Range> new_dom(rdom->domain.size()); std::vector<IterVar> new_dom(rdom.size());
bool changed = false; bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) { for (size_t i = 0; i < rdom.size(); i++) {
Range r = rdom->domain[i]; IterVar v = rdom[i];
Range r = v->dom;
Expr new_min = m->Mutate(r->min); Expr new_min = m->Mutate(r->min);
Expr new_extent = m->Mutate(r->extent); Expr new_extent = m->Mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true; if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true; if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent); new_dom[i] = IterVarNode::make(
v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag);
} }
if (!changed) { if (!changed) {
return rdom; return rdom;
} else { } else {
return RDomain::make(rdom->index, Domain(new_dom)); return Array<IterVar>(new_dom);
} }
} }
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) { .set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m); Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->Mutate(op->source); Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) && if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) { op->source.same_as(new_source)) {
......
...@@ -45,15 +45,15 @@ using namespace Halide::Internal; ...@@ -45,15 +45,15 @@ using namespace Halide::Internal;
void NoOp(const NodeRef& n, IRVisitor* v) { void NoOp(const NodeRef& n, IRVisitor* v) {
} }
inline void VisitArray(Array<Expr> arr, IRVisitor* v) { inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
for (size_t i = 0; i < arr.size(); i++) { for (size_t i = 0; i < arr.size(); i++) {
v->Visit(arr[i]); v->Visit(arr[i]);
} }
} }
inline void VisitRDom(RDomain rdom, IRVisitor* v) { inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
for (size_t i = 0; i < rdom->domain.size(); i++) { for (size_t i = 0; i < rdom.size(); i++) {
Range r = rdom->domain[i]; Range r = rdom[i]->dom;
v->Visit(r->min); v->Visit(r->min);
v->Visit(r->extent); v->Visit(r->extent);
} }
......
...@@ -67,7 +67,6 @@ void MakeLoop(const DimSplitNode* op, ...@@ -67,7 +67,6 @@ void MakeLoop(const DimSplitNode* op,
Stmt MakePipeline(const Schedule& sch, Stmt body) { Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body; return body;
} }
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
...@@ -14,6 +13,19 @@ TEST(Tensor, Basic) { ...@@ -14,6 +13,19 @@ TEST(Tensor, Basic) {
}, "C"); }, "C");
} }
TEST(Tensor, Reduce) {
using namespace tvm;
Var m("m"), n("n"), l("l");
Tensor A({m, l}, "A");
Tensor B({n, l}, "B");
IterVar rv(Range{0, l}, "k");
auto C = Compute({m, n}, [&](Var i, Var j) {
return sum(max(A(i, rv) * B(j, rv), 1), {rv});
}, "C");
LOG(INFO) << C->op.as<ComputeOpNode>()->body;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -7,7 +7,7 @@ def test_tensor(): ...@@ -7,7 +7,7 @@ def test_tensor():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(T)
print(T.op.body) print(T.op.body)
assert(tuple(T.shape) == (m, n, l)) assert(tuple(T.shape) == (m, n, l))
assert(A.source is None) assert(A.source is None)
...@@ -19,8 +19,8 @@ def test_tensor_reduce(): ...@@ -19,8 +19,8 @@ def test_tensor_reduce():
A = tvm.placeholder((m, l), name='A') A = tvm.placeholder((m, l), name='A')
B = tvm.placeholder((n, l), name='B') B = tvm.placeholder((n, l), name='B')
T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
rd = tvm.RDomain(tvm.Range(A.shape[1])) rv = tvm.IterVar((0, A.shape[1]), name="k")
C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd)) C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))
print(C.op.body) print(C.op.body)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment