Unverified Commit 07399e02 by Tianqi Chen Committed by GitHub

[RELAY][OP] Maketuple to be resolved when containing incompleteType (#2031)

parent 866d458c
......@@ -140,7 +140,7 @@ bool ConcatenateRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
CHECK(types[0].as<TupleTypeNode>())
CHECK(types[0].as<IncompleteTypeNode>())
<< "cast: expect input type to be TupleType but get "
<< types[0];
return false;
......
......@@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array<Type>& types,
return true;
}
bool MakeTupleRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(static_cast<size_t>(num_inputs + 1), types.size());
for (int i = 0; i < num_inputs; ++i) {
if (types[i].as<IncompleteTypeNode>()) return false;
}
Array<Type> fields;
for (int i = 0; i < num_inputs; ++i) {
fields.push_back(types[i]);
}
reporter->Assign(types[num_inputs], TupleTypeNode::make(fields));
return true;
}
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel);
TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
MakeTupleRel);
struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {}
......@@ -104,6 +124,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
TypeSolver solver_;
// relation function
TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;
// Unify two types
Type Unify(const Type& t1, const Type& t2, const Span& span) {
// TODO(tqchen, jroesch): propagate span to solver
......@@ -154,14 +175,19 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
}
Type VisitExpr_(const TupleNode* op) final {
// TODO(tqchen, jroesch)
// tuple should be a constraint in the type solver
// to handle cases where the field type is not known.
Array<Type> fields;
if (!make_tuple_rel_.defined()) {
make_tuple_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
}
Array<Type> types;
for (Expr field : op->fields) {
fields.push_back(GetType(field));
types.push_back(GetType(field));
}
return TupleTypeNode::make(fields);
Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
types.push_back(rtype);
solver_.AddConstraint(TypeRelationNode::make(
make_tuple_rel_, types, op->fields.size(), Attrs()));
return rtype;
}
Type VisitExpr_(const TupleGetItemNode* op) final {
......
......@@ -87,6 +87,7 @@ def test_concatenate_infer_type():
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t, 200))
x = relay.exp(x)
z = relay.concatenate((x, y), axis=2)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((n, t, 200))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment