Unverified Commit f1e5b3fb by Yen-Chen Lin Committed by GitHub

Merge pull request #35 from WillBrennan/feature/replace-searchsorted

replace searchsorted with in-built function
parents ec26d1c1 434e7814
...@@ -14,9 +14,6 @@ This project is a faithful PyTorch implementation of [NeRF](http://www.matthewta ...@@ -14,9 +14,6 @@ This project is a faithful PyTorch implementation of [NeRF](http://www.matthewta
git clone https://github.com/yenchenlin/nerf-pytorch.git git clone https://github.com/yenchenlin/nerf-pytorch.git
cd nerf-pytorch cd nerf-pytorch
pip install -r requirements.txt pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../
``` ```
<details> <details>
......
...@@ -4,9 +4,6 @@ import torch.nn as nn ...@@ -4,9 +4,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
# TODO: remove this dependency
from torchsearchsorted import searchsorted
# Misc # Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2) img2mse = lambda x, y : torch.mean((x - y) ** 2)
...@@ -223,7 +220,7 @@ def sample_pdf(bins, weights, N_samples, det=False, pytest=False): ...@@ -223,7 +220,7 @@ def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Invert CDF # Invert CDF
u = u.contiguous() u = u.contiguous()
inds = searchsorted(cdf, u, side='right') inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1) below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
......
# Prerequisites
*.d
# Object files
*.o
*.ko
*.obj
*.elf
# Linker output
*.ilk
*.map
*.exp
# Precompiled Headers
*.gch
*.pch
# Libraries
*.lib
*.a
*.la
*.lo
# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib
# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex
# Debug files
*.dSYM/
*.su
*.idb
*.pdb
# Kernel Module Compile Results
*.mod*
*.cmd
.tmp_versions/
modules.order
Module.symvers
Mkfile.old
dkms.conf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
BSD 3-Clause License
Copyright (c) 2019, Inria (Antoine Liutkus)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Pytorch Custom CUDA kernel for searchsorted
This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0.
> Warnings:
> * only works with pytorch > v1.3 and CUDA >= v10.1
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.
## Description
Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please not that the current implementation *does not correctly handle this parameter*. Help welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7)
the output is of size as `(nrows, ncols_v)`. If all input tensors are on GPU, a cuda version will be called. Otherwise, it will be on CPU.
**Disclaimers**
* This function has not been heavily tested. Use at your own risks
* When `a` is not sorted, the results vary from numpy's version. But I decided not to care about this because the function should not be called in this case.
* In some cases, the results vary from numpy's version. However, as far as I could see, this only happens when values are equal, which means we actually don't care about the order in which this value is added. I decided not to care about this also.
* vectors have to be contiguous for torchsearchsorted to give consistant results. use `.contiguous()` on all tensor arguments before calling
## Installation
Just `pip install .`, in the root folder of this repo. This will compile
and install the torchsearchsorted module.
be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is)
For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
So I had to do:
> sudo apt-get install g++-8 gcc-8
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++
be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3
## Usage
Just import the torchsearchsorted package after installation. I typically do:
```
from torchsearchsorted import searchsorted
```
## Testing
Under the `examples` subfolder, you may:
1. try `python test.py` with `torch` available.
```
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4851.592ms
CPU: searchsorted in 4805.432ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 1.055ms
difference between GPU and NUMPY: 0.000
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4333.964ms
CPU: searchsorted in 4753.958ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 0.391ms
difference between GPU and NUMPY: 0.000
```
The first run comprises the time of allocation, while the second one does not.
2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs:
```
Benchmark searchsorted:
- a [5000 x 300]
- v [5000 x 100]
- reporting fastest time of 20 runs
- each run executes searchsorted 100 times
Numpy: 4.6302046799100935
CPU: 5.041533078998327
CUDA: 0.0007955809123814106
```
import timeit
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
B = 5_000
A = 300
V = 100
repeats = 20
number = 100
print(
f'Benchmark searchsorted:',
f'- a [{B} x {A}]',
f'- v [{B} x {V}]',
f'- reporting fastest time of {repeats} runs',
f'- each run executes searchsorted {number} times',
sep='\n',
end='\n\n'
)
def get_arrays():
a = np.sort(np.random.randn(B, A), axis=1)
v = np.random.randn(B, V)
out = np.empty_like(v, dtype=np.long)
return a, v, out
def get_tensors(device):
a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
v = torch.randn(B, V, device=device)
out = torch.empty(B, V, device=device, dtype=torch.long)
if torch.cuda.is_available():
torch.cuda.synchronize()
return a, v, out
def searchsorted_synchronized(a,v,out=None,side='left'):
out = searchsorted(a,v,out,side)
torch.cuda.synchronize()
return out
numpy = timeit.repeat(
stmt="numpy_searchsorted(a, v, side='left')",
setup="a, v, out = get_arrays()",
globals=globals(),
repeat=repeats,
number=number
)
print('Numpy: ', min(numpy), sep='\t')
cpu = timeit.repeat(
stmt="searchsorted(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cpu')",
globals=globals(),
repeat=repeats,
number=number
)
print('CPU: ', min(cpu), sep='\t')
if torch.cuda.is_available():
gpu = timeit.repeat(
stmt="searchsorted_synchronized(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cuda')",
globals=globals(),
repeat=repeats,
number=number
)
print('CUDA: ', min(gpu), sep='\t')
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted
import time
if __name__ == '__main__':
# defining the number of tests
ntests = 2
# defining the problem dimensions
nrows_a = 50000
nrows_v = 50000
nsorted_values = 300
nvalues = 1000
# defines the variables. The first run will comprise allocation, the
# further ones will not
test_GPU = None
test_CPU = None
for ntest in range(ntests):
print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
nrows_a,
nsorted_values))
side = 'right'
# generate a matrix with sorted rows
a = torch.randn(nrows_a, nsorted_values, device='cpu')
a = torch.sort(a, dim=1)[0]
# generate a matrix of values to searchsort
v = torch.randn(nrows_v, nvalues, device='cpu')
# a = torch.tensor([[0., 1.]])
# v = torch.tensor([[1.]])
t0 = time.time()
test_NP = torch.tensor(numpy_searchsorted(a, v, side))
print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
t0 = time.time()
test_CPU = searchsorted(a, v, test_CPU, side)
print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CPU = torch.norm(test_NP.double()
- test_CPU.double()).numpy()
if error_CPU:
import ipdb; ipdb.set_trace()
print(' difference between CPU and NUMPY: %0.3f' % error_CPU)
if not torch.cuda.is_available():
print('CUDA is not available on this machine, cannot go further.')
continue
else:
# now do the CPU
a = a.to('cuda')
v = v.to('cuda')
torch.cuda.synchronize()
# launch searchsorted on those
t0 = time.time()
test_GPU = searchsorted(a, v, test_GPU, side)
torch.cuda.synchronize()
print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CUDA = torch.norm(test_NP.to('cuda').double()
- test_GPU.double()).cpu().numpy()
print(' difference between GPU and NUMPY: %0.3f' % error_CUDA)
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension
# In any case, include the CPU version
modules = [
CppExtension('torchsearchsorted.cpu',
['src/cpu/searchsorted_cpu_wrapper.cpp']),
]
# If nvcc is available, add the CUDA extension
if CUDA_HOME:
modules.append(
CUDAExtension('torchsearchsorted.cuda',
['src/cuda/searchsorted_cuda_wrapper.cpp',
'src/cuda/searchsorted_cuda_kernel.cu'])
)
tests_require = [
'pytest',
]
# Now proceed to setup
setup(
name='torchsearchsorted',
version='1.1',
description='A searchsorted implementation for pytorch',
keywords='searchsorted',
author='Antoine Liutkus',
author_email='antoine.liutkus@inria.fr',
packages=find_packages(where='src'),
package_dir={"": "src"},
ext_modules=modules,
tests_require=tests_require,
extras_require={
'test': tests_require,
},
cmdclass={
'build_ext': BuildExtension
}
)
#include "searchsorted_cpu_wrapper.h"
#include <stdio.h>
template<typename scalar_t>
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template<typename scalar_t>
int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
returns:
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1]. (if side_left), or
- a[row, col] < val <= a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left)
{
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(nrow_a, nrow_v);
//auto acc_v = v.accessor<float, 2>();
//auto acc_res = res.accessor<float, 2>();
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {
scalar_t* a_data = a.data_ptr<scalar_t>();
scalar_t* v_data = v.data_ptr<scalar_t>();
int64_t* res_data = res.data<int64_t>();
for (int64_t row = 0; row < nrow_res; row++)
{
for (int64_t col = 0; col < ncol_v; col++)
{
// get the value to look for
int64_t row_in_v = (nrow_v == 1) ? 0 : row;
int64_t row_in_a = (nrow_a == 1) ? 0 : row;
int64_t idx_in_v = row_in_v * ncol_v + col;
int64_t idx_in_res = row * ncol_v + col;
// apply binary search
res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
}
}
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
}
#ifndef _SEARCHSORTED_CPU
#define _SEARCHSORTED_CPU
#include <torch/extension.h>
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
\ No newline at end of file
#include "searchsorted_cuda_kernel.h"
template <typename scalar_t>
__device__
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template <typename scalar_t>
__device__
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
Returns
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1]. (if side_left), or
- a[row, col] < val <= a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
template <typename scalar_t>
__global__
void searchsorted_kernel(
int64_t *res,
scalar_t *a,
scalar_t *v,
int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
{
// get current row and column
int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
// check whether we are outside the bounds of what needs be computed.
if ((row >= nrow_res) || (col >= ncol_v)) {
return;}
// get the value to look for
int64_t row_in_v = (nrow_v==1) ? 0: row;
int64_t row_in_a = (nrow_a==1) ? 0: row;
int64_t idx_in_v = row_in_v*ncol_v+col;
int64_t idx_in_res = row*ncol_v+col;
// apply binary search
res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
}
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left){
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(double(nrow_a), double(nrow_v));
// prepare the kernel configuration
dim3 threads(ncol_v, nrow_res);
dim3 blocks(1, 1);
if (nrow_res*ncol_v > 1024){
threads.x = int(fmin(double(1024), double(ncol_v)));
threads.y = floor(1024/threads.x);
blocks.x = ceil(double(ncol_v)/double(threads.x));
blocks.y = ceil(double(nrow_res)/double(threads.y));
}
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
searchsorted_kernel<scalar_t><<<blocks, threads>>>(
res.data<int64_t>(),
a.data<scalar_t>(),
v.data<scalar_t>(),
nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
}));
}
#ifndef _SEARCHSORTED_CUDA_KERNEL
#define _SEARCHSORTED_CUDA_KERNEL
#include <torch/extension.h>
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
#include "searchsorted_cuda_wrapper.h"
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
{
CHECK_INPUT(a);
CHECK_INPUT(v);
CHECK_INPUT(res);
searchsorted_cuda(a, v, res, side_left);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
}
#ifndef _SEARCHSORTED_CUDA_WRAPPER
#define _SEARCHSORTED_CUDA_WRAPPER
#include <torch/extension.h>
#include "searchsorted_cuda_kernel.h"
void searchsorted_cuda_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
from .searchsorted import searchsorted
from .utils import numpy_searchsorted
from typing import Optional
import torch
# trying to import the CPU searchsorted
SEARCHSORTED_CPU_AVAILABLE = True
try:
from torchsearchsorted.cpu import searchsorted_cpu_wrapper
except ImportError:
SEARCHSORTED_CPU_AVAILABLE = False
# trying to import the CUDA searchsorted
SEARCHSORTED_GPU_AVAILABLE = True
try:
from torchsearchsorted.cuda import searchsorted_cuda_wrapper
except ImportError:
SEARCHSORTED_GPU_AVAILABLE = False
def searchsorted(a: torch.Tensor, v: torch.Tensor,
out: Optional[torch.LongTensor] = None,
side='left') -> torch.LongTensor:
assert len(a.shape) == 2, "input `a` must be 2-D."
assert len(v.shape) == 2, "input `v` mus(t be 2-D."
assert (a.shape[0] == v.shape[0]
or a.shape[0] == 1
or v.shape[0] == 1), ("`a` and `v` must have the same number of "
"rows or one of them must have only one ")
assert a.device == v.device, '`a` and `v` must be on the same device'
result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
if out is not None:
assert out.device == a.device, "`out` must be on the same device as `a`"
assert out.dtype == torch.long, "out.dtype must be torch.long"
assert out.shape == result_shape, ("If the output tensor is provided, "
"its shape must be correct.")
else:
out = torch.empty(result_shape, device=v.device, dtype=torch.long)
if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
raise Exception('torchsearchsorted on CUDA device is asked, but it seems '
'that it is not available. Please install it')
if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
raise Exception('torchsearchsorted on CPU is not available. '
'Please install it.')
left_side = 1 if side=='left' else 0
if a.is_cuda:
searchsorted_cuda_wrapper(a, v, out, left_side)
else:
searchsorted_cpu_wrapper(a, v, out, left_side)
return out
import numpy as np
def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
"""Numpy version of searchsorted that works batch-wise on pytorch tensors
"""
nrows_a = a.shape[0]
(nrows_v, ncols_v) = v.shape
nrows_out = max(nrows_a, nrows_v)
out = np.empty((nrows_out, ncols_v), dtype=np.long)
def sel(data, row):
return data[0] if data.shape[0] == 1 else data[row]
for row in range(nrows_out):
out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
return out
import pytest
import torch
devices = {'cpu': torch.device('cpu')}
if torch.cuda.is_available():
devices['cuda'] = torch.device('cuda:0')
@pytest.fixture(params=devices.values(), ids=devices.keys())
def device(request):
return request.param
import pytest
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
from itertools import product, repeat
def test_searchsorted_output_dtype(device):
B = 100
A = 50
V = 12
a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
v = torch.rand(B, A, device=device)
out = searchsorted(a, v)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
out = torch.empty(v.shape, dtype=torch.long, device=device)
searchsorted(a, v, out)
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
Ba_val = [1, 100, 200]
Bv_val = [1, 100, 200]
A_val = [1, 50, 500]
V_val = [1, 12, 120]
side_val = ['left', 'right']
nrepeat = 100
@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
if Ba > 1 and Bv > 1 and Ba != Bv:
return
for test in range(nrepeat):
a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
v = torch.rand(Bv, V, device=device)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
side=side)
out = searchsorted(a, v, side=side).cpu().numpy()
np.testing.assert_array_equal(out, out_np)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment