Commit 5324b211 by tqchen

[API] expose dir

parent 8de0a083
...@@ -131,6 +131,20 @@ class NodeRef { ...@@ -131,6 +131,20 @@ class NodeRef {
inline NodeType node_type() const; inline NodeType node_type() const;
/*! \return wheyjer the expression is null */ /*! \return wheyjer the expression is null */
inline bool is_null() const; inline bool is_null() const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator==(const NodeRef& other) const;
/*!
* \brief Comparator
* \param other Another node ref.
* \return the compare result.
*/
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
protected: protected:
template<typename T, typename> template<typename T, typename>
...@@ -182,5 +196,26 @@ inline bool NodeRef::is_null() const { ...@@ -182,5 +196,26 @@ inline bool NodeRef::is_null() const {
return node_.get() == nullptr; return node_.get() == nullptr;
} }
inline bool NodeRef::operator==(const NodeRef& other) const {
return node_.get() == other.node_.get();
}
inline bool NodeRef::operator!=(const NodeRef& other) const {
return node_.get() != other.node_.get();
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
}
} // namespace tvm } // namespace tvm
namespace std {
template <>
struct hash<::tvm::NodeRef> {
std::size_t operator()(const ::tvm::NodeRef& k) const {
return k.hash();
}
};
} // namespace std
#endif // TVM_BASE_H_ #endif // TVM_BASE_H_
...@@ -136,4 +136,14 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, ...@@ -136,4 +136,14 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
ArgVariant* out_value, ArgVariant* out_value,
int* out_typeid); int* out_typeid);
/*!
* \brief get attributes names in the node.
* \param handle The node handle
* \param out_size The number of functions
* \param out_array The array of function names.
*/
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array);
#endif // TVM_C_API_H_ #endif // TVM_C_API_H_
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
namespace tvm { namespace tvm {
// using Domain = Array<Range>; //using Domain = Array<Range>;
class RDomain : public NodeRef {
};
} // namespace tvm } // namespace tvm
......
...@@ -113,4 +113,13 @@ inline Expr constant(T value) { ...@@ -113,4 +113,13 @@ inline Expr constant(T value) {
} }
} // namespace tvm } // namespace tvm
namespace std {
template <>
struct hash<::tvm::Expr> {
std::size_t operator()(const ::tvm::NodeRef& k) const {
return k.hash();
}
};
} // namespace std
#endif // TVM_EXPR_H_ #endif // TVM_EXPR_H_
...@@ -46,6 +46,7 @@ class IntNode : public ExprNode { ...@@ -46,6 +46,7 @@ class IntNode : public ExprNode {
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value); visitor->Visit("value", &value);
visitor->Visit("dtype", &dtype_);
} }
}; };
...@@ -64,6 +65,7 @@ class FloatNode : public ExprNode { ...@@ -64,6 +65,7 @@ class FloatNode : public ExprNode {
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value); visitor->Visit("value", &value);
visitor->Visit("dtype", &dtype_);
} }
}; };
...@@ -94,6 +96,7 @@ class UnaryOpNode : public ExprNode { ...@@ -94,6 +96,7 @@ class UnaryOpNode : public ExprNode {
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op); visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
} }
void VisitNodeRefFields(FNodeRefVisit fvisit) override { void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("src", &src); fvisit("src", &src);
...@@ -130,12 +133,51 @@ struct BinaryOpNode : public ExprNode { ...@@ -130,12 +133,51 @@ struct BinaryOpNode : public ExprNode {
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op); visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
} }
void VisitNodeRefFields(FNodeRefVisit fvisit) override { void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("lhs", &lhs); fvisit("lhs", &lhs);
fvisit("rhs", &rhs); fvisit("rhs", &rhs);
} }
}; };
/*! \brief Binary mapping operator */
struct ReduceNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The source operand */
Expr src;
/*! \brief The reduction domain */
RDomain rdom;
/*! \brief constructor, do not use constructor */
ReduceNode() {
node_type_ = kReduceNode;
}
ReduceNode(const BinaryOp* op, Expr && src, RDomain && rdom)
: op(op), src(std::move(src)), rdom(std::move(rdom)) {
node_type_ = kReduceNode;
dtype_ = this->src.dtype();
}
~ReduceNode() {
this->Destroy();
}
const char* type_key() const override {
return "ReduceNode";
}
void Verify() const override {
CHECK_EQ(dtype_, src.dtype());
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("op", &op);
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("src", &src);
fvisit("rdom", &rdom);
}
};
} // namespace tvm } // namespace tvm
#endif // TVM_EXPR_NODE_H_ #endif // TVM_EXPR_NODE_H_
...@@ -7,9 +7,41 @@ ...@@ -7,9 +7,41 @@
#define TVM_EXPR_UTIL_H_ #define TVM_EXPR_UTIL_H_
#include "./expr.h" #include "./expr.h"
#include "./expr_node.h"
namespace tvm { namespace tvm {
/*!
* \brief simplify the expression src
* \param src The source expression
* \return the simplified expression.
*/
Expr Simplify(const Expr& src);
/*!
* \brief visit the exression node in expr tree in post DFS order.
* \param expr The expression tree
* \param fvisit The visit function.
*/
template<typename FVisit>
inline void Visit(const Expr& expr, FVisit fvisit) {
// TODO(tqchen) change to stack based impl.
switch (expr.node_type()) {
case kBinaryOpNode: {
const auto* n = expr.Get<BinaryOpNode>();
Visit(n->lhs, fvisit);
Visit(n->rhs, fvisit);
break;
}
case kUnaryOpNode: {
const auto* n = expr.Get<UnaryOpNode>();
Visit(n->src, fvisit);
break;
}
default: break;
}
fvisit(expr);
}
} // namespace tvm } // namespace tvm
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./domain.h"
namespace tvm { namespace tvm {
...@@ -27,6 +28,13 @@ class BinaryOp { ...@@ -27,6 +28,13 @@ class BinaryOp {
*/ */
Expr operator()(Expr lhs, Expr rhs) const; Expr operator()(Expr lhs, Expr rhs) const;
/*! /*!
* \brief make a reduction of src over rdom,
* \param src Source expression.
* \param rdom reduction domain.
* \return the result expr
*/
Expr Reduce(Expr src, RDomain rdom) const;
/*!
* \brief get binary op by name * \brief get binary op by name
* \param name name of operator * \param name name of operator
*/ */
...@@ -112,6 +120,12 @@ class MinOp : public BinaryOp { ...@@ -112,6 +120,12 @@ class MinOp : public BinaryOp {
return (*op)(lhs, rhs); \ return (*op)(lhs, rhs); \
} }
#define DEFINE_REDUCE_FUNCTION(FuncName, OpName) \
inline Expr FuncName(Expr src, RDomain rdom) { \
static const BinaryOp* op = BinaryOp::Get(#OpName); \
return op->Reduce(src, rdom); \
}
DEFINE_BINARY_OP_OVERLOAD(+); DEFINE_BINARY_OP_OVERLOAD(+);
DEFINE_BINARY_OP_OVERLOAD(-); DEFINE_BINARY_OP_OVERLOAD(-);
DEFINE_BINARY_OP_OVERLOAD(*); DEFINE_BINARY_OP_OVERLOAD(*);
...@@ -120,6 +134,10 @@ DEFINE_BINARY_OP_OVERLOAD(/); ...@@ -120,6 +134,10 @@ DEFINE_BINARY_OP_OVERLOAD(/);
DEFINE_BINARY_OP_FUNCTION(max); DEFINE_BINARY_OP_FUNCTION(max);
DEFINE_BINARY_OP_FUNCTION(min); DEFINE_BINARY_OP_FUNCTION(min);
DEFINE_REDUCE_FUNCTION(max, max);
DEFINE_REDUCE_FUNCTION(min, min);
DEFINE_REDUCE_FUNCTION(sum, +);
// overload negation // overload negation
inline Expr operator-(Expr src) { inline Expr operator-(Expr src) {
return src * (-1); return src * (-1);
......
...@@ -11,6 +11,7 @@ from .._base import _LIB ...@@ -11,6 +11,7 @@ from .._base import _LIB
from .._base import c_str, py_str, string_types from .._base import c_str, py_str, string_types
from .._base import FunctionHandle, NodeHandle from .._base import FunctionHandle, NodeHandle
from .._base import check_call, ctypes2docstring from .._base import check_call, ctypes2docstring
from .. import _function_internal
class ArgVariant(ctypes.Union): class ArgVariant(ctypes.Union):
...@@ -71,6 +72,27 @@ class NodeBase(object): ...@@ -71,6 +72,27 @@ class NodeBase(object):
ctypes.byref(ret_val), ctypes.byref(ret_typeid))) ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val) return RET_SWITCH[ret_typeid.value](ret_val)
def __hash__(self):
return _function_internal._raw_ptr(self)
def __eq__(self, other):
if not isinstance(other, NodeBase):
return False
return self.__hash__() == other.__hash__()
def __ne__(self, other):
return not self.__eq__(other)
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMNodeListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def _push_arg(arg): def _push_arg(arg):
a = ArgVariant() a = ArgVariant()
......
...@@ -42,5 +42,5 @@ class Var(Expr): ...@@ -42,5 +42,5 @@ class Var(Expr):
pass pass
@register_node("BinaryOpNode") @register_node("BinaryOpNode")
class BinaryOpNode(Expr): class BinaryOpExpr(Expr):
pass pass
...@@ -59,6 +59,29 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -59,6 +59,29 @@ struct APIAttrGetter : public AttrVisitor {
} }
}; };
struct APIAttrDir : public AttrVisitor {
std::vector<std::string>* names;
void Visit(const char* key, double* value) override {
names->push_back(key);
}
void Visit(const char* key, int64_t* value) override {
names->push_back(key);
}
void Visit(const char* key, DataType* value) override {
names->push_back(key);
}
void Visit(const char* key, std::string* value) override {
names->push_back(key);
}
void Visit(const char* key, const UnaryOp** value) override {
names->push_back(key);
}
void Visit(const char* key, const BinaryOp** value) override {
names->push_back(key);
}
};
const char *TVMGetLastError() { const char *TVMGetLastError() {
return TVMAPIThreadLocalStore::Get()->last_error.c_str(); return TVMAPIThreadLocalStore::Get()->last_error.c_str();
} }
...@@ -190,6 +213,29 @@ int TVMNodeGetAttr(NodeHandle handle, ...@@ -190,6 +213,29 @@ int TVMNodeGetAttr(NodeHandle handle,
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
} }
int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str.clear();
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
(*tnode)->VisitAttrs(&dir);
(*tnode)->VisitNodeRefFields([ret](const char* key, NodeRef* ref) {
ret->ret_vec_str.push_back(key);
});
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val, inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val,
int* ret_typeid) { int* ret_typeid) {
APIVariantValue& rv = ret_value; APIVariantValue& rv = ret_value;
......
...@@ -46,6 +46,13 @@ TVM_REGISTER_API(_binary_op) ...@@ -46,6 +46,13 @@ TVM_REGISTER_API(_binary_op)
.add_argument("lhs", "Expr", "left operand") .add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand"); .add_argument("rhs", "Expr", "right operand");
TVM_REGISTER_API(_raw_ptr)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
*ret = reinterpret_cast<int64_t>(args.at(0).sptr.get());
})
.add_argument("src", "NodeBase", "the node base");
// transformations // transformations
TVM_REGISTER_API(format_str) TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
......
...@@ -20,6 +20,13 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const { ...@@ -20,6 +20,13 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
return Expr(std::move(nptr)); return Expr(std::move(nptr));
} }
Expr BinaryOp::Reduce(Expr src, RDomain rdom) const {
auto nptr = std::make_shared<ReduceNode>(
this, std::move(src), std::move(rdom));
nptr->Verify();
return Expr(std::move(nptr));
}
const BinaryOp* BinaryOp::Get(const char* name) { const BinaryOp* BinaryOp::Get(const char* name) {
const auto* op = dmlc::Registry<BinaryOpReg>::Find(name); const auto* op = dmlc::Registry<BinaryOpReg>::Find(name);
CHECK(op != nullptr) << "cannot find " << name; CHECK(op != nullptr) << "cannot find " << name;
......
...@@ -4,6 +4,9 @@ def test_basic(): ...@@ -4,6 +4,9 @@ def test_basic():
a = tvm.Var('a') a = tvm.Var('a')
b = tvm.Var('b') b = tvm.Var('b')
c = a + b c = a + b
assert a == c.lhs
assert c.dtype == tvm.int32
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name) assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment