Commit 0068781d by tqchen

Check in Tensor API on python

parent bcea8f6f
Subproject commit f72e313118a61b0cc49987b9eebfc77300d2de0d Subproject commit bd94f8c8e41b46ae7ca69a3405aac7463a4e23d5
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the domain in AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <ir/Range.h>
#include <memory>
#include "./base.h"
#include "./expr.h"
namespace tvm {
/*! \brief container class of reduction domain */
class RDomainNode;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomain";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
};
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
// overload print function
inline std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r->domain << ")";
return os;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
...@@ -27,6 +27,7 @@ using Halide::abs; ...@@ -27,6 +27,7 @@ using Halide::abs;
using Halide::select; using Halide::select;
using Halide::Expr; using Halide::Expr;
using Halide::Internal::Stmt;
using Var = Halide::VarExpr; using Var = Halide::VarExpr;
} // namespace tvm } // namespace tvm
......
...@@ -6,11 +6,12 @@ ...@@ -6,11 +6,12 @@
#ifndef TVM_TENSOR_H_ #ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include <tvm/array.h>
#include <ir/FunctionBase.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <type_traits> #include <type_traits>
#include <tvm/array.h>
#include <ir/FunctionBase.h>
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
...@@ -46,6 +47,7 @@ class Tensor : public FunctionRef { ...@@ -46,6 +47,7 @@ class Tensor : public FunctionRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief constructor of input tensor * \brief constructor of input tensor
* \param shape Shape of the tensor. * \param shape Shape of the tensor.
...@@ -101,14 +103,14 @@ class Tensor : public FunctionRef { ...@@ -101,14 +103,14 @@ class Tensor : public FunctionRef {
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public Node { class TensorNode : public Node {
public: public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief optional name of the tensor */ /*! \brief optional name of the tensor */
std::string name; std::string name;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
Type dtype; Type dtype;
/*! \brief The index representing each dimension, used by source expression. */ /*! \brief The index representing each dimension, used by source expression. */
Array<Var> dim_var; Array<Var> dim_var;
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief source expression */ /*! \brief source expression */
Expr source; Expr source;
/*! \brief constructor */ /*! \brief constructor */
...@@ -117,13 +119,17 @@ class TensorNode : public Node { ...@@ -117,13 +119,17 @@ class TensorNode : public Node {
return "Tensor"; return "Tensor";
} }
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("dim_var", &dim_var); v->Visit("dim_var", &dim_var);
v->Visit("shape", &shape);
v->Visit("source", &source); v->Visit("source", &source);
} }
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Array<Var> dim_var,
Expr source);
}; };
// implementations // implementations
......
...@@ -7,3 +7,4 @@ from . import expr ...@@ -7,3 +7,4 @@ from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import collections from . import collections
from . import tensor
...@@ -107,13 +107,13 @@ def convert(value): ...@@ -107,13 +107,13 @@ def convert(value):
"""Convert a value to expression.""" """Convert a value to expression."""
if isinstance(value, Number): if isinstance(value, Number):
return const(value) return const(value)
elif isinstance(value, list): elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value] value = [convert(x) for x in value]
return _function_internal._Array(*value) return _function_internal._Array(*value)
else: else:
if not isinstance(value, NodeBase): if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value)) raise ValueError("don't know how to handle type %s" % type(value))
return value
def _push_arg(arg): def _push_arg(arg):
a = ArgVariant() a = ArgVariant()
...@@ -172,7 +172,7 @@ def _make_function(handle, name): ...@@ -172,7 +172,7 @@ def _make_function(handle, name):
"""TVM function""" """TVM function"""
cargs = [] cargs = []
for x in args: for x in args:
if isinstance(x, list): if isinstance(x, (list, tuple)):
cargs.append(convert(x)) cargs.append(convert(x))
else: else:
cargs.append(x) cargs.append(x)
......
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node from ._ctypes._api import NodeBase, register_node
from . import _function_internal from . import _function_internal
......
...@@ -54,6 +54,7 @@ class LogicalExpr(Expr): ...@@ -54,6 +54,7 @@ class LogicalExpr(Expr):
@register_node("Variable") @register_node("Variable")
class Var(Expr): class Var(Expr):
pass pass
@register_node @register_node
...@@ -162,6 +163,12 @@ class Broadcast(Expr): ...@@ -162,6 +163,12 @@ class Broadcast(Expr):
@register_node @register_node
class Call(Expr): class Call(Expr):
Extern = 0
ExternCPlusPlus = 1
PureExtern = 2
Halide = 3
Intrinsic = 4
PureIntrinsic = 5
pass pass
@register_node @register_node
......
...@@ -35,33 +35,45 @@ def convert(value): ...@@ -35,33 +35,45 @@ def convert(value):
"""Convert a value to expression.""" """Convert a value to expression."""
if isinstance(value, _Number): if isinstance(value, _Number):
return const(value) return const(value)
elif isinstance(value, list): elif isinstance(value, (list, tuple)):
value = [convert(x) for x in value] value = [convert(x) for x in value]
return _function_internal._Array(*value) return _function_internal._Array(*value)
else: else:
return value return value
def Range(begin, **kwargs): def Tensor(shape, fcompute=None, dtype=None, name="TensorObj"):
"""Create a TVM Range object. """Construct a tensor object in dataflow.
User can either call:
Range(10) to get a range in [0, 10)
or
Range(begin=1, extent=10), to get a range in [0, 11)
Parameters Parameters
---------- ----------
begin : Expr shape: Tuple of Expr
The beginning of the expression. The shape of the tensor
fcompute: lambda function of *indices-> value
Specifies the input source expression
extent : optional, Expr dtype: str, optional
The extent(i.e. the length) of the range. The data type of the tensor, must specify when fcompute is not specified.
name: str, optional
The name hint of the tensor
Returns
-------
tensor: tensor.Tensor
The created tensor
""" """
if "extent" in kwargs: ndim = len(shape)
return _function_internal._Range(begin, kwargs["extent"]) dim_var = [Var("dim_var%d" % i) for i in range(ndim)]
if fcompute:
source = fcompute(*dim_var)
return _function_internal._Tensor(
shape, name, source.dtype, dim_var, source)
else: else:
return _function_internal._Range(0, begin); dtype = float32 if dtype is None else dtype
return _function_internal._Tensor(
shape, name, dtype, None, None)
_init_function_module("tvm") _init_function_module("tvm")
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import make as _make
from . import expr as _expr
@register_node
class Tensor(NodeBase):
"""Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
return _make.Call(self.dtype, self.name, indices, _expr.Call.Halide, self, 0)
@property
def ndim(self):
return len(self.shape)
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* \file c_api_impl.cc * \file c_api_impl.cc
*/ */
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/domain.h>
#include <tvm/tensor.h>
#include <ir/IROperator.h> #include <ir/IROperator.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
...@@ -13,30 +15,22 @@ DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg); ...@@ -13,30 +15,22 @@ DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg);
namespace tvm { namespace tvm {
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(format_str) TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os; std::ostringstream os;
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
if (dynamic_cast<const BaseExprNode*>(sptr.get())) { if (sptr->is_type<TensorNode>()) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr(); os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) { } else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt(); os << args.at(0).operator Stmt();
...@@ -47,46 +41,11 @@ TVM_REGISTER_API(format_str) ...@@ -47,46 +41,11 @@ TVM_REGISTER_API(format_str)
}) })
.add_argument("expr", "Node", "expression to be printed"); .add_argument("expr", "Node", "expression to be printed");
TVM_REGISTER_API(_Array) TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr; *ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Range)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Range(args.at(0), args.at(1));
}) })
.add_argument("min", "Expr", "beginning of the range.") .add_argument("src", "NodeBase", "the node base");
.add_argument("extent", "Expr", "extent of the range");
} // namespace tvm } // namespace tvm
...@@ -29,6 +29,16 @@ TVM_REGISTER_API(_make_For) ...@@ -29,6 +29,16 @@ TVM_REGISTER_API(_make_For)
args.at(5)); args.at(5));
}); });
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
args.at(1),
args.at(2),
static_cast<Call::CallType>(args.at(3).operator int()),
args.at(4));
});
TVM_REGISTER_API(_make_Allocate) TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Allocate::make(args.at(0), *ret = Allocate::make(args.at(0),
...@@ -91,7 +101,6 @@ REGISTER_MAKE3(Select); ...@@ -91,7 +101,6 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp); REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast); REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let); REGISTER_MAKE3(Let);
// TODO(tqchen) Call;
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/domain.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
*ret = make_const(args.at(1), args.at(0).operator double());
} else {
LOG(FATAL) << "only accept int or float";
}
})
.add_argument("src", "Number", "source number")
.add_argument("dtype", "str", "data type");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle)
<< "need content of array to be NodeBase";
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(Range)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 1) {
*ret = Range(0, args.at(0));
} else {
*ret = Range(args.at(0), args.at(1));
}
})
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "extent of the range");
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
TVM_REGISTER_API(_RDomain)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = RDomain(args.at(0).operator Domain());
});
} // namespace tvm
...@@ -80,8 +80,12 @@ struct APIVariantValue { ...@@ -80,8 +80,12 @@ struct APIVariantValue {
return *this; return *this;
} }
inline APIVariantValue& operator=(const NodeRef& ref) { inline APIVariantValue& operator=(const NodeRef& ref) {
type_id = kNodeHandle; if (ref.node_.get() == nullptr) {
this->sptr = ref.node_; type_id = kNull;
} else {
type_id = kNodeHandle;
this->sptr = ref.node_;
}
return *this; return *this;
} }
inline APIVariantValue& operator=(const Type& value) { inline APIVariantValue& operator=(const Type& value) {
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/base.h>
#include <tvm/domain.h>
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file expr_node.cc * \file ir_node.cc
*/ */
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/expr.h> #include <tvm/expr.h>
......
...@@ -42,6 +42,20 @@ Expr Tensor::operator()(Array<Expr> indices) const { ...@@ -42,6 +42,20 @@ Expr Tensor::operator()(Array<Expr> indices) const {
(*this)->dtype, (*this)->name, indices, Call::Halide, *this); (*this)->dtype, (*this)->name, indices, Call::Halide, *this);
} }
Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype,
Array<Var> dim_var,
Expr source) {
auto n = std::make_shared<TensorNode>();
n->shape = shape;
n->name = name;
n->dtype = dtype;
n->dim_var = dim_var;
n->source = source;
return Tensor(n);
}
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm } // namespace tvm
import tvm
def test_tensor():
m = tvm.Var('m')
n = tvm.Var('n')
l = tvm.Var('l')
A = tvm.Tensor((m, l), name='A')
B = tvm.Tensor((n, l), name='B')
T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
print(tvm.format_str(T.source))
assert(tuple(T.shape) == (m, n, l))
assert(A.source is None)
if __name__ == "__main__":
test_tensor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment