Commit d3ee03eb by tqchen

expose range

parent 56e10eb0
...@@ -89,7 +89,7 @@ class Var : public Expr { ...@@ -89,7 +89,7 @@ class Var : public Expr {
}; };
Expr IntConstant(int64_t value); Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value); Expr FloatConstant(double value);
/*! \brief base of expression node */ /*! \brief base of expression node */
class ExprNode : public Node { class ExprNode : public Node {
......
...@@ -40,6 +40,18 @@ inline void Visit(const Expr& expr, FVisit fvisit) { ...@@ -40,6 +40,18 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
Visit(n->src, fvisit); Visit(n->src, fvisit);
break; break;
} }
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
Visit(n->src, fvisit);
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
Visit(n->indices[i], fvisit);
}
break;
}
default: break; default: break;
} }
fvisit(expr); fvisit(expr);
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include <string> #include <string>
#include <vector>
#include <type_traits> #include <type_traits>
#include "./expr.h" #include "./expr.h"
#include "./array.h" #include "./array.h"
...@@ -46,17 +47,17 @@ class TensorNode : public Node { ...@@ -46,17 +47,17 @@ class TensorNode : public Node {
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute // converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr (Var x)> f) { inline FCompute GetFCompute(std::function<Expr(Var x)> f) {
return [f](const Array<Var>& i) { return f(i[0]); }; return [f] (const Array<Var>& i) { return f(i[0]); };
} }
inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) { inline FCompute GetFCompute(std::function<Expr(Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1]); }; return [f] (const Array<Var>& i) { return f(i[0], i[1]); };
} }
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) { inline FCompute GetFCompute(std::function<Expr(Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); }; return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
} }
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) { inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); }; return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
} }
/*! /*!
...@@ -132,6 +133,10 @@ class Tensor : public NodeRef { ...@@ -132,6 +133,10 @@ class Tensor : public NodeRef {
* \return the result expression representing tensor read. * \return the result expression representing tensor read.
*/ */
Expr operator()(Array<Expr> indices) const; Expr operator()(Array<Expr> indices) const;
/*! \return list of input tensors to this tensor */
std::vector<Tensor> InputTensors() const;
/*! \return whether the tensor stores a result of reduction */
bool IsRTensor() const;
// printt function // printt function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*) friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t.shape() os << "Tensor(shape=" << t.shape()
......
from ._ctypes._api import NodeBase, register_node
@register_node("RangeNode")
class Range(NodeBase):
pass
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from .function import binary_op from .function import binary_op
from ._function_internal import _binary_op
class Expr(NodeBase): class Expr(NodeBase):
def __add__(self, other): def __add__(self, other):
......
...@@ -28,23 +28,6 @@ def _symbol(value): ...@@ -28,23 +28,6 @@ def _symbol(value):
return value return value
def binary_op(op, lhs, rhs):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return _function_internal._binary_op(op, _symbol(lhs), _symbol(rhs))
def max(lhs, rhs): def max(lhs, rhs):
"""Max of two expressions """Max of two expressions
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/op.h> #include <tvm/op.h>
#include <tvm/tensor.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace dmlc { namespace dmlc {
...@@ -37,7 +38,7 @@ TVM_REGISTER_API(constant) ...@@ -37,7 +38,7 @@ TVM_REGISTER_API(constant)
}) })
.add_argument("src", "Number", "source number"); .add_argument("src", "Number", "source number");
TVM_REGISTER_API(_binary_op) TVM_REGISTER_API(binary_op)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kStr); CHECK(args.at(0).type_id == kStr);
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2)); *ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
...@@ -53,11 +54,36 @@ TVM_REGISTER_API(_raw_ptr) ...@@ -53,11 +54,36 @@ TVM_REGISTER_API(_raw_ptr)
}) })
.add_argument("src", "NodeBase", "the node base"); .add_argument("src", "NodeBase", "the node base");
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Range(args.at(0), args.at(1));
})
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "end of the range");
TVM_REGISTER_API(_TensorInput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Tensor(
static_cast<Array<Expr> >(args.at(0)),
static_cast<std::string>(args.at(1)),
static_cast<DataType>(static_cast<int>(args.at(1))));
});
// transformations // transformations
TVM_REGISTER_API(format_str) TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os; std::ostringstream os;
os << Expr(args.at(0)); auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
os << args.at(0).operator RDomain();
} else if (sptr->is_type<RangeNode>()) {
os << args.at(0).operator Range();
} else {
os << args.at(0).operator Expr();
}
*ret = os.str(); *ret = os.str();
}) })
.add_argument("expr", "Expr", "expression to be printed"); .add_argument("expr", "Expr", "expression to be printed");
......
...@@ -62,7 +62,17 @@ struct APIVariantValue { ...@@ -62,7 +62,17 @@ struct APIVariantValue {
if (type_id == kNull) return T(); if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr; std::shared_ptr<Node> x = sptr;
return T(std::move(x)); T inst;
inst.node_ = std::move(x);
return inst;
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return IntConstant(operator int64_t());
if (type_id == kDouble) return FloatConstant(operator double());
CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr;
return Expr(std::move(x));
} }
inline operator double() const { inline operator double() const {
CHECK_EQ(type_id, kDouble); CHECK_EQ(type_id, kDouble);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/expr_node.h> #include <tvm/expr_node.h>
#include <tvm/expr_util.h>
#include <memory> #include <memory>
namespace tvm { namespace tvm {
...@@ -43,6 +44,24 @@ Expr Tensor::operator()(Array<Expr> indices) const { ...@@ -43,6 +44,24 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return Expr(std::move(node)); return Expr(std::move(node));
} }
std::vector<Tensor> Tensor::InputTensors() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
std::vector<Tensor> inputs;
if (n->source.is_null()) return inputs;
Visit(n->source, [&inputs](const Expr& e) {
if (e.node_type() == kTensorReadNode) {
inputs.push_back(e.Get<TensorReadNode>()->tensor);
}
});
return inputs;
}
bool Tensor::IsRTensor() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
if (n->source.is_null()) return false;
return n->source.node_type() == kReduceNode;
}
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm } // namespace tvm
...@@ -13,6 +13,11 @@ TEST(Tensor, Basic) { ...@@ -13,6 +13,11 @@ TEST(Tensor, Basic) {
auto C = Tensor({m, n}, [&](Var i, Var j) { auto C = Tensor({m, n}, [&](Var i, Var j) {
return sum(A(i, rd.i0()) * B(j, rd.i0()), rd); return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
}, "C"); }, "C");
auto inputs = C.InputTensors();
CHECK(inputs[0] == A);
CHECK(inputs[1] == B);
CHECK(C.IsRTensor());
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment