#include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/runtime/packed_func.h> #include <tvm/tvm.h> #include <tvm/ir.h> TEST(PackedFunc, Basic) { using namespace tvm; using namespace tvm::runtime; int x = 0; void* handle = &x; TVMArray a; Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK(args.num_args == 3); CHECK(args.values[0].v_float64 == 1.0); CHECK(args.type_codes[0] == kDLFloat); CHECK(args.values[1].v_handle == &a); CHECK(args.type_codes[1] == kArrayHandle); CHECK(args.values[2].v_handle == &x); CHECK(args.type_codes[2] == kHandle); *rv = Var("a"); })(1.0, &a, handle); CHECK(v->name_hint == "a"); } TEST(PackedFunc, Node) { using namespace tvm; using namespace tvm::runtime; Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK(args.num_args == 1); CHECK(args.type_codes[0] == kNodeHandle); Var b = args[0]; CHECK(x.same_as(b)); *rv = b; })(x); CHECK(t.same_as(x)); } TEST(PackedFunc, str) { using namespace tvm; using namespace tvm::runtime; PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK(args.num_args == 1); std::string x = args[0]; CHECK(x == "hello"); *rv = x; })("hello"); } TEST(PackedFunc, func) { using namespace tvm; using namespace tvm::runtime; PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); // function as arguments int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { PackedFunc f = args[0]; // TVMArgValue -> Arguments as function *rv = f(args[1]).operator int(); })(addone, 1); CHECK_EQ(r0, 2); int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { // TVMArgValue -> TVMRetValue *rv = args[1]; })(2, 100); CHECK_EQ(r1, 100); int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { // re-assignment *rv = args[0]; // TVMRetValue -> Function argument *rv = addone(args[0].operator PackedFunc()(args[1], 1)); })(addone, 100); CHECK_EQ(r2, 102); } TEST(PackedFunc, Expr) { using namespace tvm; using namespace tvm::runtime; // automatic conversion of int to expr PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { Expr x = args[0]; *rv = x.as<tvm::ir::IntImm>()->value + 1; }); int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { PackedFunc f = args[0]; // TVMArgValue -> Arguments as function *rv = f(args[1]).operator int(); })(addone, 1); CHECK_EQ(r0, 2); } TEST(PackedFunc, Type) { using namespace tvm; using namespace tvm::runtime; auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { Type x = args[0]; *rv = x; }); auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); CHECK(get_type("int32").operator Type() == Int(32)); CHECK(get_type("float").operator Type() == Float(32)); CHECK(get_type2("float32x2").operator Type() == Float(32, 2)); } // new namespoace namespace test { // register int vector as extension type using IntVector = std::vector<int>; } // namespace test namespace tvm { namespace runtime { template<> struct extension_class_info<test::IntVector> { static const int code = kExtBegin + 1; }; } // runtime } // tvm // do registration, this need to be in cc file TVM_REGISTER_EXT_TYPE(test::IntVector); TEST(PackedFunc, ExtensionType) { using namespace tvm; using namespace tvm::runtime; // note: class are copy by value. test::IntVector vec{1, 2, 4}; auto copy_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { // copy by value const test::IntVector& v = args[0].AsExtension<test::IntVector>(); CHECK(&v == &vec); test::IntVector v2 = args[0]; CHECK_EQ(v2.size(), 3U); CHECK_EQ(v[2], 4); // return copy by value *rv = v2; }); auto pass_vec = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { // copy by value *rv = args[0]; }); test::IntVector vret1 = copy_vec(vec); test::IntVector vret2 = pass_vec(copy_vec(vec)); CHECK_EQ(vret1.size(), 3U); CHECK_EQ(vret2.size(), 3U); CHECK_EQ(vret1[2], 4); CHECK_EQ(vret2[2], 4); } int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); }