Commit d0fe532e by Animesh Jain, committed by Haichen Shen

[Relay][Compile_engine] Int64 shape handling for outputs. (#4031)

parent 1bff2c89
...@@ -219,6 +219,25 @@ class ScheduleGetter :
      CHECK_EQ(call_node->args.size(), 1U)
          << "Only allow function with a single tuple input";
    }
    // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the
    // shape is Int32. The following code ensures the same for the output as well.
    // TODO(@icemelon): Support recursive tuple
    Type call_node_type = call_node->checked_type();
    if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
      call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
    } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
      std::vector<Type> new_fields;
      for (auto field : tuple_t->fields) {
        if (const auto* tt = field.as<TensorTypeNode>()) {
          new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
        } else {
          new_fields.push_back(field);
        }
      }
      call_node_type = TupleTypeNode::make(new_fields);
    }
    CHECK(call_node->op.as<OpNode>())
        << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);
...@@ -232,7 +251,7 @@ class ScheduleGetter :
                                          Operation(), 0));
      } else {
        outputs = fcompute[op](call_node->attrs, inputs,
                               call_node_type, target_);  // was: call_node->checked_type()
      }
      int op_pattern = fpattern[op];
...
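For intuition, the shape handling that GetShape performs on these types can be sketched at the Python level. The helper name normalize_shape and the exact narrowing rule are illustrative assumptions, not the actual C++ implementation; the sketch only shows the idea of forcing constant int64 dimensions down to int32 so inputs and outputs carry a uniform shape dtype:

import tvm

def normalize_shape(shape):
    # Hypothetical sketch: cast constant int64 dimensions to int32 when they fit,
    # mirroring the intent of GetShape for input and output shapes.
    out = []
    for dim in shape:
        if isinstance(dim, tvm.expr.IntImm) and dim.dtype == 'int64':
            assert dim.value < (1 << 31), "dimension does not fit in int32"
            out.append(tvm.expr.IntImm('int32', dim.value))
        else:
            out.append(dim)
    return out

# Example: mixed int32/int64 constant dims, as in the new test below.
dims = [tvm.expr.IntImm('int32', 1), tvm.expr.IntImm('int64', 16)]
print([d.dtype for d in normalize_shape(dims)])  # ['int32', 'int32']

The new test below exercises exactly this situation through relay.full.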
...@@ -79,8 +79,23 @@ def test_compile_tuple_dup():
    relay.build(relay.Module.from_expr(f), 'llvm')
def test_compile_full():
    # Shape calculations can happen in int64. This test checks that the full operator
    # can handle shapes that are not int32.
    shape = (tvm.expr.IntImm('int32', 1),
             tvm.expr.IntImm('int64', 16),
             tvm.expr.IntImm('int64', 16),
             tvm.expr.IntImm('int32', 64))
    output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
    f = relay.Function([], output)
    mod = relay.Module.from_expr(f)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    relay.build(mod, 'llvm')
if __name__ == "__main__":
    test_compile_engine()
    test_compile_placeholder_bypass()
    test_compile_injective_with_tuple()
    test_compile_tuple_dup()
    test_compile_full()
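For reference, the scenario that test_compile_full covers can also be run as a standalone script. This is a minimal sketch under the assumption of a TVM build from the same era with the LLVM backend enabled:

import tvm
from tvm import relay

# Constant dimensions with mixed int32/int64 dtypes, as shape computations can produce.
shape = (tvm.expr.IntImm('int32', 1), tvm.expr.IntImm('int64', 16))
out = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
mod = relay.Module.from_expr(relay.Function([], out))
relay.build(mod, 'llvm')  # with this change, int64 dims in the output type are narrowed to int32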