Commit 383494a5 by Tianqi Chen Committed by GitHub

[API] Move all RTTI related code to one place (#20)

* [API] Move all RTTI related code to one place

* add back rtti comment
parent 4d4e19ce
Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169 Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438
Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16 Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
...@@ -15,7 +15,7 @@ using RetValue = APIVariantValue; ...@@ -15,7 +15,7 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify) TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) { if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Stmt()); *ret = Simplify(args.at(0).operator Stmt());
} else { } else {
*ret = Simplify(args.at(0).operator Expr()); *ret = Simplify(args.at(0).operator Expr());
...@@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify) ...@@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify)
TVM_REGISTER_API(_pass_Equal) TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) { if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
CHECK(args.at(1).type_id == kNodeHandle);
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
} else { } else {
Expr a = args.at(0).operator Expr(); *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
Expr b = args.at(1).operator Expr();
*ret = Equal(a, b);
} }
}); });
......
...@@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) { ...@@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
template<typename T> template<typename T>
struct NodeTypeChecker { struct NodeTypeChecker {
static inline void Check(Node* sptr) { static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety return (dynamic_cast<ContainerType*>(sptr) != nullptr);
CHECK(dynamic_cast<ContainerType*>(sptr)) }
<< "wrong type specified, expected " << typeid(ContainerType).name(); static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
} }
}; };
template<typename T> template<typename T>
struct NodeTypeChecker<Array<T> > { struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) { static inline bool Check(Node* sptr) {
// use dynamic RTTI for safety if (sptr == nullptr) return false;
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>()) if (!sptr->is_type<ArrayNode>()) return false;
<< "wrong type specified, expected Array";
ArrayNode* n = static_cast<ArrayNode*>(sptr); ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) { for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get()); if (!NodeTypeChecker<T>::Check(p.get())) return false;
} }
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
} }
}; };
template<typename K, typename V> template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > { struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) { static inline bool Check(Node* sptr) {
// use dynamic RTTI for safety if (sptr == nullptr) return false;
CHECK(sptr != nullptr && sptr->is_type<MapNode>()) if (!sptr->is_type<MapNode>()) return false;
<< "wrong type specified, expected Map";
MapNode* n = static_cast<MapNode*>(sptr); MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) { for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get()); if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
NodeTypeChecker<V>::Check(kv.second.get()); if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
} }
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
} }
}; };
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
/*! \brief Variant container for API calls */ /*! \brief Variant container for API calls */
class APIVariantValue { class APIVariantValue {
public: public:
...@@ -127,7 +151,8 @@ class APIVariantValue { ...@@ -127,7 +151,8 @@ class APIVariantValue {
inline operator T() const { inline operator T() const {
if (type_id == kNull) return T(); if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
NodeTypeChecker<T>::Check(sptr.get()); CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr); return T(sptr);
} }
inline operator Expr() const { inline operator Expr() const {
...@@ -140,7 +165,7 @@ class APIVariantValue { ...@@ -140,7 +165,7 @@ class APIVariantValue {
if (sptr->is_type<IterVarNode>()) { if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var; return IterVar(sptr)->var;
} else { } else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get())) CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "did not pass in Expr in a place need Expr"; << "did not pass in Expr in a place need Expr";
return Expr(sptr); return Expr(sptr);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment