Commit d0fe532e by Animesh Jain Committed by Haichen Shen

[Relay][Compile_engine] Int64 shape handling for outputs. (#4031)

parent 1bff2c89
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -219,6 +219,25 @@ class ScheduleGetter :
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}
// Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
// Int32. Following code ensures the same for the output as well.
// TODO(@icemelon): Support recursive tuple
Type call_node_type = call_node->checked_type();
if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
} else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
std::vector<Type> new_fields;
for (auto field : tuple_t->fields) {
if (const auto* tt = field.as<TensorTypeNode>()) {
new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
} else {
new_fields.push_back(field);
}
}
call_node_type = TupleTypeNode::make(new_fields);
}
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
......@@ -232,7 +251,7 @@ class ScheduleGetter :
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node->checked_type(), target_);
call_node_type, target_);
}
int op_pattern = fpattern[op];
......
......@@ -79,8 +79,23 @@ def test_compile_tuple_dup():
relay.build(relay.Module.from_expr(f), 'llvm')
def test_compile_full():
    # Shape expressions may carry int64 immediates. This checks that the
    # full operator compiles even when the shape tuple mixes int32/int64.
    out_shape = (tvm.expr.IntImm('int32', 1),
                 tvm.expr.IntImm('int64', 16),
                 tvm.expr.IntImm('int64', 16),
                 tvm.expr.IntImm('int32', 64))
    filled = relay.full(relay.const(0, 'int32'), shape=out_shape, dtype='int32')
    func = relay.Function([], filled)
    module = relay.Module.from_expr(func)
    module = relay.qnn.transform.CanonicalizeOps()(module)
    relay.build(module, 'llvm')
if __name__ == "__main__":
    # Execute every compile-engine test case, in declaration order,
    # when this file is run directly as a script.
    for case in (test_compile_engine,
                 test_compile_placeholder_bypass,
                 test_compile_injective_with_tuple,
                 test_compile_tuple_dup,
                 test_compile_full):
        case()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment