Commit eee0ebef by tqchen

Stronger type checker during conversion

parent 57a74936
......@@ -27,6 +27,7 @@ using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
......@@ -35,6 +36,9 @@ class Var : public Halide::VarExpr {
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
......
......@@ -83,7 +83,6 @@ using Halide::Internal::UIntImm;
using Halide::Internal::FloatImm;
using Halide::Internal::StringImm;
using Halide::Internal::Cast;
using Halide::Internal::Variable;
using Halide::Internal::Add;
using Halide::Internal::Sub;
using Halide::Internal::Mul;
......
......@@ -10,4 +10,5 @@ from . import ir_pass
from . import collections
from . import schedule
from ._base import TVMError
from .function import *
......@@ -54,6 +54,43 @@ inline const char* TypeId2Str(ArgVariantID type_id) {
}
}
template<typename T>
struct NodeTypeChecker {
static inline void Check(Node* sptr) {
using ContainerType = typename T::ContainerType;
// use dynamic RTTI for safety
CHECK(dynamic_cast<ContainerType*>(sptr))
<< "wrong type specified, expected " << typeid(ContainerType).name();
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<ArrayNode>())
<< "wrong type specified, expected Array";
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
NodeTypeChecker<T>::Check(p.get());
}
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline void Check(Node* sptr) {
// use dynamic RTTI for safety
CHECK(sptr != nullptr && sptr->is_type<MapNode>())
<< "wrong type specified, expected Map";
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
NodeTypeChecker<K>::Check(kv.first.get());
NodeTypeChecker<V>::Check(kv.second.get());
}
}
};
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -109,13 +146,11 @@ class APIVariantValue {
return operator=(Type2String(value));
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified, expected " << typeid(typename T::ContainerType).name();
NodeTypeChecker<T>::Check(sptr.get());
return T(sptr);
}
inline operator Expr() const {
......
......@@ -10,5 +10,16 @@ def test_inline():
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
try:
# pass in int array(wrong argument type)
# must raise an error
stmt = tvm.ir_pass.Inline(
T, [1,2,3], T.op.body, stmt)
assert False
except tvm.TVMError:
pass
if __name__ == "__main__":
test_inline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment