/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

// Unit tests for PackedFunc / TypedPackedFunc: argument packing, type codes,
// and conversions between TVMArgValue / TVMRetValue and concrete types.

#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>

#include <functional>
#include <string>
#include <type_traits>

// A double, a DLTensor* and a void* must arrive with the matching type codes,
// and a Var assigned to the return value must convert back to a Var.
TEST(PackedFunc, Basic) {
  using namespace tvm;
  using namespace tvm::tir;
  using namespace tvm::runtime;
  int x = 0;
  void* handle = &x;
  // Only the address of `a` is inspected; its contents are never read.
  DLTensor a;

  Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    CHECK(args.num_args == 3);
    CHECK(args.values[0].v_float64 == 1.0);
    CHECK(args.type_codes[0] == kDLFloat);
    CHECK(args.values[1].v_handle == &a);
    CHECK(args.type_codes[1] == kTVMDLTensorHandle);
    CHECK(args.values[2].v_handle == &x);
    CHECK(args.type_codes[2] == kTVMOpaqueHandle);
    *rv = Var("a");
  })(1.0, &a, handle);
  CHECK(v->name_hint == "a");
}

// An object reference passed through a PackedFunc keeps its identity both as
// an argument and as a return value.
TEST(PackedFunc, Node) {
  using namespace tvm;
  using namespace tvm::tir;
  using namespace tvm::runtime;
  Var x;
  Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    CHECK(args.num_args == 1);
    CHECK(args.type_codes[0] == kTVMObjectHandle);
    Var b = args[0];
    CHECK(x.same_as(b));
    *rv = b;
  })(x);
  CHECK(t.same_as(x));
}

// An NDArray argument can be read back as NDArray or as DLTensor*, and the
// use counts observed while passing / returning it match the expected values.
TEST(PackedFunc, NDArray) {
  using namespace tvm;
  using namespace tvm::runtime;
  auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
  reinterpret_cast<float*>(x->data)[0] = 10.0f;
  CHECK(x.use_count() == 1);

  // Returns its first argument unchanged.
  PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });

  NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    NDArray y = args[0];
    DLTensor* ptr = args[0];
    CHECK(ptr == x.operator->());
    CHECK(x.same_as(y));
    CHECK(x.use_count() == 2);
    *rv = forward(y);
  })(x);
  CHECK(ret.use_count() == 2);
  CHECK(ret.same_as(x));
}

// A string literal argument converts to std::string inside the callee, and a
// std::string return value converts back to std::string at the call site.
TEST(PackedFunc, str) {
  using namespace tvm;
  using namespace tvm::runtime;
  std::string ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
                      CHECK(args.num_args == 1);
                      std::string x = args[0];
                      CHECK(x == "hello");
                      *rv = x;
                    })("hello")
                        .operator std::string();
  // Verify the return path as well, not only the argument path.
  CHECK(ret == "hello");
}

// PackedFunc values can themselves be passed as arguments, and argument /
// return values can be forwarded between calls.
TEST(PackedFunc, func) {
  using namespace tvm;
  using namespace tvm::runtime;
  PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) {
    *rv = args[0].operator int() + 1;
  });
  // function as arguments
  int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    PackedFunc f = args[0];
    // TVMArgValue -> Arguments as function
    *rv = f(args[1]).operator int();
  })(addone, 1);
  CHECK_EQ(r0, 2);

  int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    // TVMArgValue -> TVMRetValue
    *rv = args[1];
  })(2, 100);
  CHECK_EQ(r1, 100);

  int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    // re-assignment
    *rv = args[0];
    // TVMRetValue -> Function argument
    *rv = addone(args[0].operator PackedFunc()(args[1], 1));
  })(addone, 100);
  CHECK_EQ(r2, 102);
}

// An int argument is readable as a PrimExpr backed by an IntImmNode.
TEST(PackedFunc, Expr) {
  using namespace tvm;
  using namespace tvm::runtime;
  // automatic conversion of int to expr
  PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
    PrimExpr x = args[0];
    *rv = x.as<tvm::tir::IntImmNode>()->value + 1;
  });
  int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    PackedFunc f = args[0];
    // TVMArgValue -> Arguments as function
    *rv = f(args[1]).operator int();
  })(addone, 1);
  CHECK_EQ(r0, 2);
}

// String arguments convert to DataType, either explicitly inside the callee
// (get_type) or when the forwarded value is read back by the caller (get_type2).
TEST(PackedFunc, Type) {
  using namespace tvm;
  using namespace tvm::runtime;
  auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    DataType x = args[0];
    *rv = x;
  });
  auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });
  CHECK(get_type("int32").operator DataType() == DataType::Int(32));
  CHECK(get_type("float").operator DataType() == DataType::Float(32));
  CHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2));
}

// TypedPackedFunc values can take and return other TypedPackedFunc values, and
// interoperate with the type-erased PackedFunc form.
TEST(TypedPackedFunc, HighOrder) {
  using namespace tvm;
  using namespace tvm::runtime;
  using Int1Func = TypedPackedFunc<int(int)>;
  using Int2Func = TypedPackedFunc<int(int, int)>;
  using BindFunc = TypedPackedFunc<Int1Func(Int2Func, int value)>;
  BindFunc ftyped;
  // Binds `value` as the first argument of f1.
  ftyped = [](Int2Func f1, int value) -> Int1Func {
    auto binded = [f1, value](int x) { return f1(value, x); };
    Int1Func x(binded);
    return x;
  };
  auto add = [](int x, int y) { return x + y; };
  CHECK_EQ(ftyped(Int2Func(add), 1)(2), 3);
  PackedFunc f = ftyped(Int2Func(add), 1);
  CHECK_EQ(f(3).operator int(), 4);
  // call the type erased version.
  Int1Func f1 = ftyped.packed()(Int2Func(add), 1);
  CHECK_EQ(f1(3), 4);
}

// Compile-time checks that function_signature deduces the call signature of a
// TypedPackedFunc, a lambda and a std::function.
TEST(TypedPackedFunc, Deduce) {
  using namespace tvm::runtime;
  using tvm::runtime::detail::function_signature;

  // These objects are only used as operands of decltype below.
  TypedPackedFunc<int(float)> x;
  auto f = [](int x) -> int { return x + 1; };
  std::function<void(float)> y;

  static_assert(std::is_same<function_signature<decltype(x)>::FType, int(float)>::value,
                "invariant1");
  static_assert(std::is_same<function_signature<decltype(f)>::FType, int(int)>::value,
                "invariant2");
  static_assert(std::is_same<function_signature<decltype(y)>::FType, void(float)>::value,
                "invariant3");
}

// NDArray and Module keep their dedicated type codes when stored in a
// TVMRetValue or passed as arguments, even when wrapped in a plain ObjectRef,
// and null references convert to null handles of the requested type.
TEST(PackedFunc, ObjectConversion) {
  using namespace tvm;
  using namespace tvm::tir;
  using namespace tvm::runtime;
  TVMRetValue rv;
  auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0});
  // assign null
  rv = ObjectRef();
  CHECK_EQ(rv.type_code(), kTVMNullptr);

  // Can assign NDArray to ret type
  rv = x;
  CHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
  // Even if we assign base type it still shows as NDArray
  rv = ObjectRef(x);
  CHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
  // Check convert back
  CHECK(rv.operator NDArray().same_as(x));
  CHECK(rv.operator ObjectRef().same_as(x));
  CHECK(!rv.IsObjectRef<PrimExpr>());

  auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle);
    CHECK(args[0].operator NDArray().same_as(x));
    CHECK(args[0].operator ObjectRef().same_as(x));
    CHECK(args[1].operator ObjectRef().get() == nullptr);
    CHECK(args[1].operator NDArray().get() == nullptr);
    CHECK(args[1].operator Module().get() == nullptr);
    CHECK(args[1].operator Array<NDArray>().get() == nullptr);
    CHECK(!args[0].IsObjectRef<PrimExpr>());
  });
  pf1(x, ObjectRef());
  pf1(ObjectRef(x), NDArray());

  // testcases for modules
  auto* pf = tvm::runtime::Registry::Get("runtime.SourceModuleCreate");
  CHECK(pf != nullptr);
  Module m = (*pf)("", "xyz");
  rv = m;
  CHECK_EQ(rv.type_code(), kTVMModuleHandle);
  // Even if we assign base type it still shows as Module
  rv = ObjectRef(m);
  CHECK_EQ(rv.type_code(), kTVMModuleHandle);
  // Check convert back
  CHECK(rv.operator Module().same_as(m));
  CHECK(rv.operator ObjectRef().same_as(m));
  CHECK(!rv.IsObjectRef<NDArray>());

  auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
    CHECK_EQ(args[0].type_code(), kTVMModuleHandle);
    CHECK(args[0].operator Module().same_as(m));
    CHECK(args[0].operator ObjectRef().same_as(m));
    CHECK(args[1].operator ObjectRef().get() == nullptr);
    CHECK(args[1].operator NDArray().get() == nullptr);
    CHECK(args[1].operator Module().get() == nullptr);
    CHECK(!args[0].IsObjectRef<PrimExpr>());
  });
  pf2(m, ObjectRef());
  pf2(ObjectRef(m), Module());
}

// Test runner entry point.
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}