Commit 95ab85d0 by Jared Roesch Committed by Tianqi Chen

[Relay][VM] Fix code generation for packed functions + tuples (#3287)

parent f2ddb196
......@@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
return Instruction::AllocTensor(last_register, dltype, NewRegister());
}
void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
void EmitInvokePrimitive(const Function& func,
const std::vector<Index>& args_registers,
const Type& ret_type) {
std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
size_t return_num = 0;
// Arity calculation must flatten tuples.
size_t arity = 0;
CHECK_EQ(func->params.size(), args_registers.size());
for (size_t i = 0; i < func->params.size(); i++) {
auto ty = func->params[i]->checked_type();
if (ty.as<TensorTypeNode>()) {
unpacked_arg_regs.push_back(args_registers[i]);
arity += 1;
} else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
const auto& field = tuple_ty->fields[f];
CHECK(field.as<TensorTypeNode>())
<< "only supports non-nested tuples currently "
<< "found " << field;
auto dst = NewRegister();
Emit(Instruction::GetField(args_registers[i], f, dst));
unpacked_arg_regs.push_back(dst);
}
arity += tuple_ty->fields.size();
} else {
LOG(FATAL) << "unsupported parameter type " << ty;
}
}
size_t return_val_count = 0;
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
// Allocate space for the return tensor.
auto alloc = AllocTensorFromType(ttype);
allocs.push_back(alloc);
return_num = 1;
return_val_count = 1;
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
std::vector<Index> fields_registers;
......@@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
allocs.push_back(AllocTensorFromType(f_type));
fields_registers.push_back(allocs.back().dst);
}
return_num = ttype->fields.size();
return_val_count = ttype->fields.size();
} else {
LOG(FATAL) << "Unsupported return value type";
}
arity += return_val_count;
for (auto& alloc : allocs) {
Emit(alloc);
args_registers.push_back(alloc.dst);
unpacked_arg_regs.push_back(alloc.dst);
}
// Next generate the invoke instruction.
......@@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
op_index = seen_funcs[cfunc->funcs[0]];
}
// If Tensor, 1
// If Tuple, size of tuple
size_t arity = func->params.size() + return_num;
Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
if (return_num > 1) {
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
if (return_val_count > 1) {
// return value is a tuple, we need to create a tuple
std::vector<Index> fields_registers;
for (size_t i = func->params.size(); i < arity; ++i) {
fields_registers.push_back(args_registers[i]);
for (size_t i = arity - return_val_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
}
}
......
......@@ -49,6 +49,17 @@ def test_split():
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_split_no_fuse():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
z = relay.annotation.stop_fusion(z)
f = relay.Function([x], z)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
def test_id():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
......@@ -259,6 +270,8 @@ if __name__ == "__main__":
test_tuple_second()
test_let_scalar()
test_let_tensor()
test_split()
test_split_no_fuse()
# TODO(@jroesch): restore when match is supported
# test_list_constructor()
test_closure()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment