// threading_backend_test.cc
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <atomic>
#include <memory>
#include <thread>
#include <vector>

#include <gtest/gtest.h>
#include <tvm/runtime/c_backend_api.h>

// Number of indices summed by atomic_add_task_id; one complete launch adds
// 0 + 1 + ... + (N - 1) == N * (N - 1) / 2 to the accumulator.
constexpr size_t N = 128;

// Parallel task body shared by the tests below.
//
// The range [0, N) is cut into penv->num_task contiguous chunks of
// ceil(N / num_task) indices each (the trailing chunks may be shorter or
// empty); task `task_id` adds every index of its own chunk to the counter
// behind `cdata`.
//
// `cdata` must point to a std::atomic<size_t> (every caller passes &acc).
// Relaxed ordering is used because the tasks only need the additions to be
// atomic; the callers read the total after TVMBackendParallelLaunch returns
// (presumably the launch waits for all tasks to finish -- confirm in backend).
// Always returns 0.
static FTVMParallelLambda atomic_add_task_id = [](int task_id, TVMParallelGroupEnv* penv,
                                                  void* cdata) -> int {
  auto* counter = static_cast<std::atomic<size_t>*>(cdata);
  const size_t num_tasks = static_cast<size_t>(penv->num_task);
  const size_t chunk = (N + num_tasks - 1) / num_tasks;
  const size_t begin = static_cast<size_t>(task_id) * chunk;
  // Clamp the chunk end to N so the final chunk never runs past the range.
  const size_t end = (begin + chunk < N) ? begin + chunk : N;
  for (size_t idx = begin; idx < end; ++idx) {
    counter->fetch_add(idx, std::memory_order_relaxed);
  }
  return 0;
};

// Launches atomic_add_task_id once from the test's own thread and checks that
// every index in [0, N) was added exactly once.
// num_task == 0 presumably lets the backend choose the task count -- confirm
// against the TVMBackendParallelLaunch documentation.
TEST(ThreadingBackend, TVMBackendParallelLaunch) {
  std::atomic<size_t> acc(0);
  // Check the launch status explicitly (assumes 0 means success, per the C
  // backend API convention) so a failed launch is reported as such instead of
  // only surfacing as a wrong sum.
  EXPECT_EQ(TVMBackendParallelLaunch(atomic_add_task_id, &acc, 0), 0);
  EXPECT_EQ(acc.load(std::memory_order_relaxed), N * (N - 1) / 2);
}

// Calls TVMBackendParallelLaunch concurrently from 1 up to and including
// max_num_threads caller threads. Each caller performs num_jobs_per_thread
// launches, each with its own accumulator, and checks every resulting sum.
TEST(ThreadingBackend, TVMBackendParallelLaunchMultipleThreads) {
  // TODO(tulloch) use parameterised tests when available.
  const size_t num_jobs_per_thread = 3;
  const size_t max_num_threads = 2;

  // `<=` so that max_num_threads itself is exercised: with `<` and a maximum
  // of 2, only a single caller thread ever ran and nothing concurrent was tested.
  for (size_t num_threads = 1; num_threads <= max_num_threads; ++num_threads) {
    std::vector<std::thread> ts;
    ts.reserve(num_threads);
    for (size_t i = 0; i < num_threads; ++i) {
      // Capturing by reference is safe: every thread is joined below, before
      // any captured local goes out of scope.
      ts.emplace_back([&]() {
        for (size_t j = 0; j < num_jobs_per_thread; ++j) {
          std::atomic<size_t> acc(0);
          // Assumes 0 means success, per the C backend API convention.
          EXPECT_EQ(TVMBackendParallelLaunch(atomic_add_task_id, &acc, 0), 0);
          EXPECT_EQ(acc.load(std::memory_order_relaxed), N * (N - 1) / 2);
        }
      });
    }
    for (auto& t : ts) {
      t.join();
    }
  }
}

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}