Commit d3ee03eb by tqchen

expose range

parent 56e10eb0
......@@ -89,7 +89,7 @@ class Var : public Expr {
};
Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value);
Expr FloatConstant(double value);
/*! \brief base of expression node */
class ExprNode : public Node {
......
......@@ -40,6 +40,18 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
Visit(n->src, fvisit);
break;
}
case kReduceNode: {
const auto* n = expr.Get<ReduceNode>();
Visit(n->src, fvisit);
break;
}
case kTensorReadNode: {
const auto* n = expr.Get<TensorReadNode>();
for (size_t i = 0; i < n->indices.size(); ++i) {
Visit(n->indices[i], fvisit);
}
break;
}
default: break;
}
fvisit(expr);
......
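The two new cases let Visit recurse into reduction sources and tensor-read indices. A minimal sketch of a helper built on the extended visitor, assuming the kTensorReadNode kind and the Expr::node_type() accessor used elsewhere in this commit; the helper itself is hypothetical and not part of the change:

// hypothetical helper on top of the extended Visit; not part of this commit
inline int CountTensorReads(const Expr& e) {
  int reads = 0;
  Visit(e, [&reads](const Expr& sub) {
    if (sub.node_type() == kTensorReadNode) ++reads;
  });
  return reads;
}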
......@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <string>
#include <vector>
#include <type_traits>
#include "./expr.h"
#include "./array.h"
......@@ -46,17 +47,17 @@ class TensorNode : public Node {
using FCompute = std::function<Expr (const Array<Var>& i)>;
// converters from other functions into fcompute
inline FCompute GetFCompute(std::function<Expr (Var x)> f) {
return [f](const Array<Var>& i) { return f(i[0]); };
inline FCompute GetFCompute(std::function<Expr(Var x)> f) {
return [f] (const Array<Var>& i) { return f(i[0]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1]); };
inline FCompute GetFCompute(std::function<Expr(Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2]); };
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
}
inline FCompute GetFCompute(std::function<Expr (Var, Var, Var, Var)> f) {
return [f](const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
inline FCompute GetFCompute(std::function<Expr(Var, Var, Var, Var)> f) {
return [f] (const Array<Var>& i) { return f(i[0], i[1], i[2], i[3]); };
}
/*!
......@@ -132,6 +133,10 @@ class Tensor : public NodeRef {
* \return the result expression representing tensor read.
*/
Expr operator()(Array<Expr> indices) const;
/*! \return list of input tensors to this tensor */
std::vector<Tensor> InputTensors() const;
/*! \return whether the tensor stores a result of reduction */
bool IsRTensor() const;
// print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t.shape()
......
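The GetFCompute overloads adapt fixed-arity lambdas to the generic Array<Var> signature used by TensorNode, and InputTensors/IsRTensor expose the dataflow information computed in tensor.cc further down. A small usage sketch, assuming the Expr arithmetic operators from op.h; illustrative only:

// sketch: a two-argument lambda adapted to the generic FCompute signature
std::function<Expr(Var, Var)> f = [](Var i, Var j) { return i + j; };
FCompute fc = GetFCompute(f);
// fc now accepts an Array<Var> and forwards its first two entries to f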
from ._ctypes._api import NodeBase, register_node
@register_node("RangeNode")
class Range(NodeBase):
pass
from ._ctypes._api import NodeBase, register_node
from .function import binary_op
from ._function_internal import _binary_op
class Expr(NodeBase):
def __add__(self, other):
......
......@@ -28,23 +28,6 @@ def _symbol(value):
return value
def binary_op(op, lhs, rhs):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return _function_internal._binary_op(op, _symbol(lhs), _symbol(rhs))
def max(lhs, rhs):
"""Max of two expressions
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/tensor.h>
#include "./c_api_registry.h"
namespace dmlc {
......@@ -37,7 +38,7 @@ TVM_REGISTER_API(constant)
})
.add_argument("src", "Number", "source number");
TVM_REGISTER_API(_binary_op)
TVM_REGISTER_API(binary_op)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kStr);
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
......@@ -53,11 +54,36 @@ TVM_REGISTER_API(_raw_ptr)
})
.add_argument("src", "NodeBase", "the node base");
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Range(args.at(0), args.at(1));
})
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "end of the range");
TVM_REGISTER_API(_TensorInput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Tensor(
static_cast<Array<Expr> >(args.at(0)),
static_cast<std::string>(args.at(1)),
static_cast<DataType>(static_cast<int>(args.at(2))));
});
// transformations
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
os << Expr(args.at(0));
auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
os << args.at(0).operator RDomain();
} else if (sptr->is_type<RangeNode>()) {
os << args.at(0).operator Range();
} else {
os << args.at(0).operator Expr();
}
*ret = os.str();
})
.add_argument("expr", "Expr", "expression to be printed");
......
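For reference, the Range registration above simply forwards its two arguments to the Range constructor, with numeric arguments promoted to constant expressions by APIVariantValue (see the next hunk). A hedged C++ equivalent of calling the registered API with (0, 16):

// sketch: what Range(args.at(0), args.at(1)) constructs when plain integers are passed
Range r(IntConstant(0), IntConstant(16));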
......@@ -62,7 +62,17 @@ struct APIVariantValue {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr;
return T(std::move(x));
T inst;
inst.node_ = std::move(x);
return inst;
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return IntConstant(operator int64_t());
if (type_id == kDouble) return FloatConstant(operator double());
CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr;
return Expr(std::move(x));
}
inline operator double() const {
CHECK_EQ(type_id, kDouble);
......
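The new operator Expr() means a registered function can accept numbers and node handles interchangeably: a kLong argument becomes an IntConstant and a kDouble becomes a FloatConstant. A sketch of a registration that relies on this; the API name _range_by_extent is hypothetical and not part of this commit:

// hypothetical registration relying on APIVariantValue::operator Expr(); not part of this commit
TVM_REGISTER_API(_range_by_extent)
.set_body([](const ArgStack& args, RetValue *ret) {
    Expr extent = args.at(0);  // kLong -> IntConstant, kDouble -> FloatConstant
    *ret = Range(IntConstant(0), extent);
  })
.add_argument("extent", "Expr", "extent of the range, starting at zero");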
......@@ -4,6 +4,7 @@
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
#include <memory>
namespace tvm {
......@@ -43,6 +44,24 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return Expr(std::move(node));
}
std::vector<Tensor> Tensor::InputTensors() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
std::vector<Tensor> inputs;
if (n->source.is_null()) return inputs;
Visit(n->source, [&inputs](const Expr& e) {
if (e.node_type() == kTensorReadNode) {
inputs.push_back(e.Get<TensorReadNode>()->tensor);
}
});
return inputs;
}
bool Tensor::IsRTensor() const {
const TensorNode* n = static_cast<const TensorNode*>(node_.get());
if (n->source.is_null()) return false;
return n->source.node_type() == kReduceNode;
}
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
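Note that InputTensors records one entry per TensorReadNode it visits, so a tensor read in several places appears several times in the result. A caller-side deduplication sketch, assuming the Tensor equality used by the CHECKs in the test below (C refers to the tensor constructed there):

// hypothetical caller-side deduplication; not part of this commit
std::vector<Tensor> unique_inputs;
for (const Tensor& t : C.InputTensors()) {
  bool seen = false;
  for (const Tensor& u : unique_inputs) {
    if (u == t) { seen = true; break; }
  }
  if (!seen) unique_inputs.push_back(t);
}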
......@@ -13,6 +13,11 @@ TEST(Tensor, Basic) {
auto C = Tensor({m, n}, [&](Var i, Var j) {
return sum(A(i, rd.i0()) * B(j, rd.i0()), rd);
}, "C");
auto inputs = C.InputTensors();
CHECK(inputs[0] == A);
CHECK(inputs[1] == B);
CHECK(C.IsRTensor());
}
int main(int argc, char ** argv) {
......