Unverified Commit 07399e02 by Tianqi Chen Committed by GitHub

[RELAY][OP] Maketuple to be resolved when containing incompleteType (#2031)

parent 866d458c
...@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
const auto* tensor_tuple = types[0].as<TupleTypeNode>(); const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) { if (tensor_tuple == nullptr) {
CHECK(types[0].as<TupleTypeNode>()) CHECK(types[0].as<IncompleteTypeNode>())
<< "cast: expect input type to be TupleType but get " << "cast: expect input type to be TupleType but get "
<< types[0]; << types[0];
return false; return false;
......
...@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types, ...@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types,
return true; return true;
} }
bool MakeTupleRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
for (int i = 0; i < num_inputs; ++i) {
if (types[i].as<IncompleteTypeNode>()) return false;
}
Array<Type> fields;
for (int i = 0; i < num_inputs; ++i) {
fields.push_back(types[i]);
}
reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel); TupleGetItemRel);
TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
MakeTupleRel);
struct ResolvedTypeInfo { struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {} : checked_type(checked_type), type_args(type_args) {}
...@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
TypeSolver solver_; TypeSolver solver_;
// relation function // relation function
TypeRelationFn tuple_getitem_rel_; TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;
// Unify two types // Unify two types
Type Unify(const Type& t1, const Type& t2, const Span& span) { Type Unify(const Type& t1, const Type& t2, const Span& span) {
// TODO(tqchen, jroesch): propagate span to solver // TODO(tqchen, jroesch): propagate span to solver
...@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
Type VisitExpr_(const TupleNode* op) final { Type VisitExpr_(const TupleNode* op) final {
// TODO(tqchen, jroesch) if (!make_tuple_rel_.defined()) {
// tuple should be a constraint in the type solver make_tuple_rel_ = TypeRelationFn(
// to handle cases where the field type is not known. EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
Array<Type> fields; }
Array<Type> types;
for (Expr field : op->fields) { for (Expr field : op->fields) {
fields.push_back(GetType(field)); types.push_back(GetType(field));
} }
return TupleTypeNode::make(fields); Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
types.push_back(rtype);
solver_.AddConstraint(TypeRelationNode::make(
make_tuple_rel_, types, op->fields.size(), Attrs()));
return rtype;
} }
Type VisitExpr_(const TupleGetItemNode* op) final { Type VisitExpr_(const TupleGetItemNode* op) final {
......
...@@ -87,6 +87,7 @@ def test_concatenate_infer_type(): ...@@ -87,6 +87,7 @@ def test_concatenate_infer_type():
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t, 200)) assert zz.checked_type == relay.TensorType((n, t, 200))
x = relay.exp(x)
z = relay.concatenate((x, y), axis=2) z = relay.concatenate((x, y), axis=2)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t, 200)) assert zz.checked_type == relay.TensorType((n, t, 200))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment