Commit 383494a5 by Tianqi Chen Committed by GitHub

[API] Move all RTTI related code to one place (#20)

* [API] Move all RTTI related code to one place

* add back rtti comment
parent 4d4e19ce
Subproject commit 6375e6b76f6b70d58f66b357d946c971843f3169
Subproject commit af2a2fcee59378f33817d7745a8110b9cc836438
Subproject commit 749e570c19423fe679a5f496e2394ba3bed75a16
Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
......@@ -15,7 +15,7 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_pass_Simplify)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Simplify(args.at(0).operator Stmt());
} else {
*ret = Simplify(args.at(0).operator Expr());
......@@ -24,13 +24,10 @@ TVM_REGISTER_API(_pass_Simplify)
TVM_REGISTER_API(_pass_Equal)
.set_body([](const ArgStack& args, RetValue *ret) {
if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
CHECK(args.at(1).type_id == kNodeHandle);
if (NodeTypeChecker<Stmt>::Check(args.at(0).sptr.get())) {
*ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
} else {
Expr a = args.at(0).operator Expr();
Expr b = args.at(1).operator Expr();
*ret = Equal(a, b);
*ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
}
});
......
......@@ -33,41 +33,65 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -127,7 +151,8 @@ class APIVariantValue {
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
NodeTypeChecker<T>::Check(sptr.get());
CHECK(NodeTypeChecker<T>::Check(sptr.get()))
<< "Did not get expected type " << NodeTypeName<T>();
return T(sptr);
}
inline operator Expr() const {
......@@ -140,7 +165,7 @@ class APIVariantValue {
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment