Commit eee0ebef by tqchen

Stronger type checker during conversion

parent 57a74936
...@@ -27,6 +27,7 @@ using Halide::IR::FunctionRef; ...@@ -27,6 +27,7 @@ using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode; using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt; using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter; using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
...@@ -35,6 +36,9 @@ class Var : public Halide::VarExpr { ...@@ -35,6 +36,9 @@ class Var : public Halide::VarExpr {
Type t = Int(32)) : VarExpr(name_hint, t) {} Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {} explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
/*! \brief type indicate the container type */
using ContainerType = Variable;
}; };
......
...@@ -83,7 +83,6 @@ using Halide::Internal::UIntImm; ...@@ -83,7 +83,6 @@ using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm; using Halide::Internal::FloatImm;
using Halide::Internal::StringImm; using Halide::Internal::StringImm;
using Halide::Internal::Cast; using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add; using Halide::Internal::Add;
using Halide::Internal::Sub; using Halide::Internal::Sub;
using Halide::Internal::Mul; using Halide::Internal::Mul;
......
...@@ -10,4 +10,5 @@ from . import ir_pass ...@@ -10,4 +10,5 @@ from . import ir_pass
from . import collections from . import collections
from . import schedule from . import schedule
from ._base import TVMError
from .function import * from .function import *
...@@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) { ...@@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
} }
} }
template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
}
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
}
}
};
/*! \brief Variant container for API calls */ /*! \brief Variant container for API calls */
class APIVariantValue { class APIVariantValue {
public: public:
...@@ -113,9 +150,7 @@ class APIVariantValue { ...@@ -113,9 +150,7 @@ class APIVariantValue {
inline operator T() const { inline operator T() const {
if (type_id == kNull) return T(); if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety NodeTypeChecker<T>::Check(sptr.get());
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified, expected " << typeid(typename T::ContainerType).name();
return T(sptr); return T(sptr);
} }
inline operator Expr() const { inline operator Expr() const {
......
...@@ -10,5 +10,16 @@ def test_inline(): ...@@ -10,5 +10,16 @@ def test_inline():
print(stmt) print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt)) assert(tvm.ir_pass.VerifySSA(stmt))
try:
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt)
assert False
except tvm.TVMError:
pass
if __name__ == "__main__": if __name__ == "__main__":
test_inline() test_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment