Commit bcb67009 by bmild, committed by Yen-Chen Lin

v0.1 release to public

**/.ipynb_checkpoints
**/__pycache__
*.png
*.mp4
*.npy
*.npz
*.dae
data/*
logs/*
MIT License
Copyright (c) 2020 bmild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# NeRF-pytorch
[NeRF](http://www.matthewtancik.com/nerf) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif)
![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif)
This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code has been tested to match the authors' TensorFlow implementation [here](https://github.com/bmild/nerf) numerically.
## Installation
```
git clone https://github.com/yenchenlin/nerf-pytorch.git
cd nerf-pytorch
pip install -r requirements.txt
cd torchsearchsorted
pip install .
cd ../
```
<details>
<summary> Dependencies (click to expand) </summary>
## Dependencies
- PyTorch 1.4
- matplotlib
- numpy
- imageio
- imageio-ffmpeg
- configargparse
The LLFF data loader requires ImageMagick.
You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
</details>
## How To Run?
### Quick Start
Download the data for two example datasets, `lego` and `fern`:
```
bash download_example_data.sh
```
To train a low-res `lego` NeRF:
```
python run_nerf_torch.py --config configs/config_lego.txt
```
After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.
![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif)
---
To train a low-res `fern` NeRF:
```
python run_nerf_torch.py --config configs/config_fern.txt
```
After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following videos at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`.
![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif)
---
### More Datasets
To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
```
├── configs
│   ├── ...
│
├── data
│   ├── nerf_llff_data
│   │   ├── fern
│   │   ├── flower   # downloaded llff dataset
│   │   ├── horns    # downloaded llff dataset
│   │   ├── ...
│   ├── nerf_synthetic
│   │   ├── lego
│   │   ├── ship     # downloaded synthetic dataset
│   │   ├── ...
```
---
To train NeRF on different datasets:
```
python run_nerf_torch.py --config configs/config_{DATASET}.txt
```
Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
---
To test NeRF trained on different datasets:
```
python run_nerf_torch.py --config configs/config_{DATASET}.txt --render_only
```
Replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
### Pre-trained Models
You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
```
├── logs
│   ├── fern_test
│   ├── flower_test   # downloaded logs
│   ├── trex_test     # downloaded logs
```
### Reproducibility
Tests that ensure the results of all functions and the training loop match the official implementation live in a separate branch, `reproduce`. You can check it out and run the tests:
```
git checkout reproduce
py.test
```
## Method
[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
[Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*<sup>1</sup>,
[Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*<sup>1</sup>,
[Matthew Tancik](http://tancik.com/)\*<sup>1</sup>,
[Jonathan T. Barron](http://jonbarron.info/)<sup>2</sup>,
[Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)<sup>3</sup>,
[Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)<sup>1</sup> <br>
<sup>1</sup>UC Berkeley, <sup>2</sup>Google Research, <sup>3</sup>UC San Diego
\*denotes equal contribution
<img src='imgs/pipeline.jpg'/>
> A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume", so we can use volume rendering to differentiably render new views.
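
For intuition, here is a minimal sketch of that volume-rendering step (an illustration written for this README, not the repository's actual rendering code), assuming per-sample colors `rgb`, densities `sigma`, and inter-sample distances `dists` along each ray:
```
import torch

def composite(rgb, sigma, dists):
    # alpha_i = 1 - exp(-sigma_i * delta_i): opacity contributed by sample i
    alpha = 1. - torch.exp(-torch.relu(sigma) * dists)
    # T_i = prod_{j<i} (1 - alpha_j): transmittance up to sample i
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * T
    # expected color along the ray: sum_i w_i * c_i
    return torch.sum(weights[..., None] * rgb, -2)
```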
## Citation
Kudos to the authors for their amazing results:
```
@misc{mildenhall2020nerf,
title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
year={2020},
eprint={2003.08934},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
However, if you find this implementation or the pre-trained models helpful, please consider citing:
```
@misc{lin2020nerfpytorch,
title={NeRF-pytorch},
author={Yen-Chen, Lin},
howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
year={2020}
}
```
expname = fern_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fern
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = flower_test
basedir = ./logs
datadir = ./data/nerf_llff_data/flower
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = fortress_test
basedir = ./logs
datadir = ./data/nerf_llff_data/fortress
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = horns_test
basedir = ./logs
datadir = ./data/nerf_llff_data/horns
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
expname = lego_test
basedir = ./logs
datadir = ./data/nerf_synthetic/lego
dataset_type = blender
half_res = True
N_samples = 64
N_importance = 64
use_viewdirs = True
white_bkgd = True
N_rand = 1024
expname = trex_test
basedir = ./logs
datadir = ./data/nerf_llff_data/trex
dataset_type = llff
factor = 8
llffhold = 8
N_rand = 1024
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1e0
wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
mkdir -p data
cd data
wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
unzip nerf_example_data.zip
cd ..
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2
trans_t = lambda t : torch.Tensor([
[1,0,0,0],
[0,1,0,0],
[0,0,1,t],
[0,0,0,1]]).float()
rot_phi = lambda phi : torch.Tensor([
[1,0,0,0],
[0,np.cos(phi),-np.sin(phi),0],
[0,np.sin(phi), np.cos(phi),0],
[0,0,0,1]]).float()
rot_theta = lambda th : torch.Tensor([
[np.cos(th),0,-np.sin(th),0],
[0,1,0,0],
[np.sin(th),0, np.cos(th),0],
[0,0,0,1]]).float()
def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi/180.*np.pi) @ c2w
c2w = rot_theta(theta/180.*np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
return c2w
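# Illustrative example (values are hypothetical): place a camera on a sphere of
# radius 4, 30 degrees below the equator, rotated by theta about the vertical axis:
#   c2w = pose_spherical(theta=90., phi=-30., radius=4.)  # (4, 4) camera-to-world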
def load_blender_data(basedir, half_res=False, testskip=1):
splits = ['train', 'val', 'test']
metas = {}
for s in splits:
with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
metas[s] = json.load(fp)
all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s=='train' or testskip==0:
skip = 1
else:
skip = testskip
for frame in meta['frames'][::skip]:
fname = os.path.join(basedir, frame['file_path'] + '.png')
imgs.append(imageio.imread(fname))
poses.append(np.array(frame['transform_matrix']))
imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)
i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)
H, W = imgs[0].shape[:2]
camera_angle_x = float(meta['camera_angle_x'])
focal = .5 * W / np.tan(.5 * camera_angle_x)
render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
if half_res:
H = H//2
W = W//2
focal = focal/2.
imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
for i, img in enumerate(imgs):
imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)  # cv2.resize expects (width, height)
imgs = imgs_half_res
# imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
return imgs, poses, render_poses, [H, W, focal], i_split
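# Illustrative usage (the path is hypothetical):
#   imgs, poses, render_poses, hwf, i_split = load_blender_data(
#       './data/nerf_synthetic/lego', half_res=True, testskip=8)
# imgs is (N, H, W, 4) RGBA in [0, 1], poses is (N, 4, 4) camera-to-world,
# and i_split holds the [train, val, test] index arrays.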
import os
import numpy as np
import imageio
def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
# Get camera intrinsics
with open(filepath, 'r') as file:
f, cx, cy = list(map(float, file.readline().split()))[:3]
grid_barycenter = np.array(list(map(float, file.readline().split())))
near_plane = float(file.readline())
scale = float(file.readline())
height, width = map(float, file.readline().split())
try:
world2cam_poses = int(file.readline())
except ValueError:
world2cam_poses = None
if world2cam_poses is None:
world2cam_poses = False
world2cam_poses = bool(world2cam_poses)
print(cx,cy,f,height,width)
cx = cx / width * trgt_sidelength
cy = cy / height * trgt_sidelength
f = trgt_sidelength / height * f
fx = f
if invert_y:
fy = -f
else:
fy = f
# Build the intrinsic matrices
full_intrinsic = np.array([[fx, 0., cx, 0.],
[0., fy, cy, 0],
[0., 0, 1, 0],
[0, 0, 0, 1]])
return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
def load_pose(filename):
assert os.path.isfile(filename)
nums = open(filename).read().split()
return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
H = 512
W = 512
deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
focal = full_intrinsic[0,0]
print(H, W, focal)
def dir2poses(posedir):
poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
transf = np.array([
[1,0,0,0],
[0,-1,0,0],
[0,0,-1,0],
[0,0,0,1.],
])
poses = poses @ transf
poses = poses[:,:3,:4].astype(np.float32)
return poses
posedir = os.path.join(deepvoxels_base, 'pose')
poses = dir2poses(posedir)
testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
testposes = testposes[::testskip]
valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
valposes = valposes[::testskip]
imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
testimgd = '{}/test/{}/rgb'.format(basedir, scene)
imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
all_imgs = [imgs, valimgs, testimgs]
counts = [0] + [x.shape[0] for x in all_imgs]
counts = np.cumsum(counts)
i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate([poses, valposes, testposes], 0)
render_poses = testposes
print(poses.shape, imgs.shape)
return imgs, poses, render_poses, [H,W,focal], i_split
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# TODO: remove this dependency
from torchsearchsorted import searchsorted
# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
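# Illustrative usage (variable names assumed): psnr = mse2psnr(img2mse(rgb_pred, rgb_gt))
# computes PSNR = -10 * log10(MSE); to8b maps [0, 1] floats to uint8 video frames.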
# Positional encoding (section 5.1)
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
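# Illustrative example: with multires=10 and 3-D inputs, the embedding concatenates
# the identity plus sin/cos at 10 frequencies, so out_dim = 3 + 3*2*10 = 63:
#   embed_fn, out_dim = get_embedder(10)      # out_dim == 63
#   encoded = embed_fn(torch.rand(1024, 3))   # shape (1024, 63)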
# Model
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
"""
"""
super(NeRF, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
return outputs
def load_weights_from_keras(self, weights):
assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
# Load pts_linears
for i in range(self.D):
idx_pts_linears = 2 * i
self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
# Load feature_linear
idx_feature_linear = 2 * self.D
self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
# Load views_linears
idx_views_linears = 2 * self.D + 2
self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
# Load rgb_linear
idx_rbg_linear = 2 * self.D + 4
self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))
# Load alpha_linear
idx_alpha_linear = 2 * self.D + 6
self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
# Ray helpers
def get_rays(H, W, focal, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d
def get_rays_np(H, W, focal, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
return rays_o, rays_d
def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[...,2]) / rays_d[...,2]
rays_o = rays_o + t[...,None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)
return rays_o, rays_d
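# Illustrative (hypothetical call): for forward-facing LLFF scenes, rays are
# typically mapped into NDC with the near plane at 1 before sampling,
#   rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
# after which depth can be sampled uniformly in [0, 1] along each ray.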
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF
u = u.contiguous()
inds = searchsorted(cdf, u, side='right')
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
return samples
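# Illustrative usage (shapes assumed): with bins = midpoints of the coarse
# z-values, shape (N_rays, N_samples-1), and the coarse weights trimmed to
# (N_rays, N_samples-2), draw N_importance fine samples per ray by inverting
# the piecewise-constant CDF:
#   z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance)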
# Prerequisites
*.d
# Object files
*.o
*.ko
*.obj
*.elf
# Linker output
*.ilk
*.map
*.exp
# Precompiled Headers
*.gch
*.pch
# Libraries
*.lib
*.a
*.la
*.lo
# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib
# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex
# Debug files
*.dSYM/
*.su
*.idb
*.pdb
# Kernel Module Compile Results
*.mod*
*.cmd
.tmp_versions/
modules.order
Module.symvers
Mkfile.old
dkms.conf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
BSD 3-Clause License
Copyright (c) 2019, Inria (Antoine Liutkus)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Pytorch Custom CUDA kernel for searchsorted
This repository is an implementation of the searchsorted function for pytorch CUDA Tensors. It was initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but has been completely rewritten since then, because building C extensions is no longer supported as of pytorch 1.0.
> Warnings:
> * only works with pytorch > v1.3 and CUDA >= v10.1
> * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.
## Description
Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
* `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
* `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
* `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
* `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help is welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7).
The output is of shape `(nrows, ncols_v)`. If all input tensors are on the GPU, the CUDA version is called; otherwise, the CPU version runs.
**Disclaimers**
* This function has not been heavily tested. Use at your own risk.
* When `a` is not sorted, the results differ from numpy's version. I decided not to care about this, because the function should not be called in that case.
* In some other cases the results differ from numpy's version. As far as I could tell, this only happens when values are equal, in which case the insertion order does not matter anyway. I decided not to care about this either.
* Tensors have to be contiguous for torchsearchsorted to give consistent results. Call `.contiguous()` on all tensor arguments before calling.
## Installation
Just `pip install .` in the root folder of this repo. This will compile and install the torchsearchsorted module.
Be careful: sometimes `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is located).
For instance, on my machine I had `gcc` and `g++` v9 installed, but `nvcc` required v8, so I had to do:
> sudo apt-get install g++-8 gcc-8
> sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
> sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++
Note that pytorch must already be installed on your system. The code was tested on pytorch v1.3.
## Usage
Just import the torchsearchsorted package after installation. I typically do:
```
from torchsearchsorted import searchsorted
```
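A minimal call might then look like this (shapes and values are illustrative):
```
import torch
from torchsearchsorted import searchsorted

a = torch.sort(torch.randn(5, 300), dim=1)[0]  # each row sorted ascending
v = torch.randn(5, 100)                        # values to locate, row by row
out = searchsorted(a.contiguous(), v.contiguous())  # LongTensor of shape (5, 100)
```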
## Testing
Under the `examples` subfolder, you may:
1. Try `python test.py` with `torch` available:
```
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4851.592ms
CPU: searchsorted in 4805.432ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 1.055ms
difference between GPU and NUMPY: 0.000
Looking for 50000x1000 values in 50000x300 entries
NUMPY: searchsorted in 4333.964ms
CPU: searchsorted in 4753.958ms
difference between CPU and NUMPY: 0.000
GPU: searchsorted in 0.391ms
difference between GPU and NUMPY: 0.000
```
The first run comprises the time of allocation, while the second one does not.
2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), which benchmarks `searchsorted` over many runs:
```
Benchmark searchsorted:
- a [5000 x 300]
- v [5000 x 100]
- reporting fastest time of 20 runs
- each run executes searchsorted 100 times
Numpy: 4.6302046799100935
CPU: 5.041533078998327
CUDA: 0.0007955809123814106
```
import timeit
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
B = 5_000
A = 300
V = 100
repeats = 20
number = 100
print(
f'Benchmark searchsorted:',
f'- a [{B} x {A}]',
f'- v [{B} x {V}]',
f'- reporting fastest time of {repeats} runs',
f'- each run executes searchsorted {number} times',
sep='\n',
end='\n\n'
)
def get_arrays():
a = np.sort(np.random.randn(B, A), axis=1)
v = np.random.randn(B, V)
out = np.empty_like(v, dtype=np.long)
return a, v, out
def get_tensors(device):
a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
v = torch.randn(B, V, device=device)
out = torch.empty(B, V, device=device, dtype=torch.long)
if torch.cuda.is_available():
torch.cuda.synchronize()
return a, v, out
def searchsorted_synchronized(a, v, out=None, side='left'):
out = searchsorted(a, v, out, side)
torch.cuda.synchronize()
return out
numpy = timeit.repeat(
stmt="numpy_searchsorted(a, v, side='left')",
setup="a, v, out = get_arrays()",
globals=globals(),
repeat=repeats,
number=number
)
print('Numpy: ', min(numpy), sep='\t')
cpu = timeit.repeat(
stmt="searchsorted(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cpu')",
globals=globals(),
repeat=repeats,
number=number
)
print('CPU: ', min(cpu), sep='\t')
if torch.cuda.is_available():
gpu = timeit.repeat(
stmt="searchsorted_synchronized(a, v, out, side='left')",
setup="a, v, out = get_tensors(device='cuda')",
globals=globals(),
repeat=repeats,
number=number
)
print('CUDA: ', min(gpu), sep='\t')
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted
import time
if __name__ == '__main__':
# defining the number of tests
ntests = 2
# defining the problem dimensions
nrows_a = 50000
nrows_v = 50000
nsorted_values = 300
nvalues = 1000
# defines the variables. The first run will comprise allocation, the
# further ones will not
test_GPU = None
test_CPU = None
for ntest in range(ntests):
print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
nrows_a,
nsorted_values))
side = 'right'
# generate a matrix with sorted rows
a = torch.randn(nrows_a, nsorted_values, device='cpu')
a = torch.sort(a, dim=1)[0]
# generate a matrix of values to searchsort
v = torch.randn(nrows_v, nvalues, device='cpu')
# a = torch.tensor([[0., 1.]])
# v = torch.tensor([[1.]])
t0 = time.time()
test_NP = torch.tensor(numpy_searchsorted(a, v, side))
print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
t0 = time.time()
test_CPU = searchsorted(a, v, test_CPU, side)
print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CPU = torch.norm(test_NP.double()
- test_CPU.double()).numpy()
if error_CPU:
import ipdb; ipdb.set_trace()
print(' difference between CPU and NUMPY: %0.3f' % error_CPU)
if not torch.cuda.is_available():
print('CUDA is not available on this machine, cannot go further.')
continue
else:
# now do the GPU version
a = a.to('cuda')
v = v.to('cuda')
torch.cuda.synchronize()
# launch searchsorted on those
t0 = time.time()
test_GPU = searchsorted(a, v, test_GPU, side)
torch.cuda.synchronize()
print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
# compute the difference between both
error_CUDA = torch.norm(test_NP.to('cuda').double()
- test_GPU.double()).cpu().numpy()
print(' difference between GPU and NUMPY: %0.3f' % error_CUDA)
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
from torch.utils.cpp_extension import CppExtension, CUDAExtension
# In any case, include the CPU version
modules = [
CppExtension('torchsearchsorted.cpu',
['src/cpu/searchsorted_cpu_wrapper.cpp']),
]
# If nvcc is available, add the CUDA extension
if CUDA_HOME:
modules.append(
CUDAExtension('torchsearchsorted.cuda',
['src/cuda/searchsorted_cuda_wrapper.cpp',
'src/cuda/searchsorted_cuda_kernel.cu'])
)
tests_require = [
'pytest',
]
# Now proceed to setup
setup(
name='torchsearchsorted',
version='1.1',
description='A searchsorted implementation for pytorch',
keywords='searchsorted',
author='Antoine Liutkus',
author_email='antoine.liutkus@inria.fr',
packages=find_packages(where='src'),
package_dir={"": "src"},
ext_modules=modules,
tests_require=tests_require,
extras_require={
'test': tests_require,
},
cmdclass={
'build_ext': BuildExtension
}
)
#include "searchsorted_cpu_wrapper.h"
#include <stdio.h>
template<typename scalar_t>
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template<typename scalar_t>
int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
returns:
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1] (if side_left), or
- a[row, col] <= val < a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left)
{
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(nrow_a, nrow_v);
//auto acc_v = v.accessor<float, 2>();
//auto acc_res = res.accessor<float, 2>();
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {
scalar_t* a_data = a.data_ptr<scalar_t>();
scalar_t* v_data = v.data_ptr<scalar_t>();
int64_t* res_data = res.data<int64_t>();
for (int64_t row = 0; row < nrow_res; row++)
{
for (int64_t col = 0; col < ncol_v; col++)
{
// get the value to look for
int64_t row_in_v = (nrow_v == 1) ? 0 : row;
int64_t row_in_a = (nrow_a == 1) ? 0 : row;
int64_t idx_in_v = row_in_v * ncol_v + col;
int64_t idx_in_res = row * ncol_v + col;
// apply binary search
res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
}
}
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
}
#ifndef _SEARCHSORTED_CPU
#define _SEARCHSORTED_CPU
#include <torch/extension.h>
void searchsorted_cpu_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
#include "searchsorted_cuda_kernel.h"
template <typename scalar_t>
__device__
int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
{
/* Evaluates whether a[row,col] < val <= a[row, col+1]*/
if (col == ncol - 1)
{
// special case: we are on the right border
if (a[row * ncol + col] <= val){
return 1;}
else {
return -1;}
}
bool is_lower;
bool is_next_higher;
if (side_left) {
// a[row, col] < v <= a[row, col+1]
is_lower = (a[row * ncol + col] < val);
is_next_higher = (a[row*ncol + col + 1] >= val);
} else {
// a[row, col] <= v < a[row, col+1]
is_lower = (a[row * ncol + col] <= val);
is_next_higher = (a[row * ncol + col + 1] > val);
}
if (is_lower && is_next_higher) {
// we found the right spot
return 0;
} else if (is_lower) {
// answer is on the right side
return 1;
} else {
// answer is on the left side
return -1;
}
}
template <typename scalar_t>
__device__
int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
{
/* Look for the value `val` within row `row` of matrix `a`, which
has `ncol` columns.
the `a` matrix is assumed sorted in increasing order, row-wise
Returns
* -1 if `val` is smaller than the smallest value found within that row of `a`
* `ncol` - 1 if `val` is larger than the largest element of that row of `a`
* Otherwise, return the column index `res` such that:
- a[row, col] < val <= a[row, col+1] (if side_left), or
- a[row, col] <= val < a[row, col+1] (if not side_left).
*/
//start with left at 0 and right at number of columns of a
int64_t right = ncol;
int64_t left = 0;
while (right >= left) {
// take the midpoint of current left and right cursors
int64_t mid = left + (right-left)/2;
// check the relative position of val: are we good here ?
int rel_pos = eval(val, a, row, mid, ncol, side_left);
// we found the point
if(rel_pos == 0) {
return mid;
} else if (rel_pos > 0) {
if (mid==ncol-1){return ncol-1;}
// the answer is on the right side
left = mid;
} else {
if (mid==0){return -1;}
right = mid;
}
}
return -1;
}
template <typename scalar_t>
__global__
void searchsorted_kernel(
int64_t *res,
scalar_t *a,
scalar_t *v,
int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
{
// get current row and column
int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
// check whether we are outside the bounds of what needs be computed.
if ((row >= nrow_res) || (col >= ncol_v)) {
return;}
// get the value to look for
int64_t row_in_v = (nrow_v==1) ? 0: row;
int64_t row_in_a = (nrow_a==1) ? 0: row;
int64_t idx_in_v = row_in_v*ncol_v+col;
int64_t idx_in_res = row*ncol_v+col;
// apply binary search
res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
}
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left){
// Get the dimensions
auto nrow_a = a.size(/*dim=*/0);
auto nrow_v = v.size(/*dim=*/0);
auto ncol_a = a.size(/*dim=*/1);
auto ncol_v = v.size(/*dim=*/1);
auto nrow_res = fmax(double(nrow_a), double(nrow_v));
// prepare the kernel configuration
dim3 threads(ncol_v, nrow_res);
dim3 blocks(1, 1);
if (nrow_res*ncol_v > 1024){
threads.x = int(fmin(double(1024), double(ncol_v)));
threads.y = floor(1024/threads.x);
blocks.x = ceil(double(ncol_v)/double(threads.x));
blocks.y = ceil(double(nrow_res)/double(threads.y));
}
AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
searchsorted_kernel<scalar_t><<<blocks, threads>>>(
res.data<int64_t>(),
a.data<scalar_t>(),
v.data<scalar_t>(),
nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
}));
}
#ifndef _SEARCHSORTED_CUDA_KERNEL
#define _SEARCHSORTED_CUDA_KERNEL
#include <torch/extension.h>
void searchsorted_cuda(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
#include "searchsorted_cuda_wrapper.h"
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
{
CHECK_INPUT(a);
CHECK_INPUT(v);
CHECK_INPUT(res);
searchsorted_cuda(a, v, res, side_left);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
}
#ifndef _SEARCHSORTED_CUDA_WRAPPER
#define _SEARCHSORTED_CUDA_WRAPPER
#include <torch/extension.h>
#include "searchsorted_cuda_kernel.h"
void searchsorted_cuda_wrapper(
at::Tensor a,
at::Tensor v,
at::Tensor res,
bool side_left);
#endif
from .searchsorted import searchsorted
from .utils import numpy_searchsorted
from typing import Optional
import torch
# trying to import the CPU searchsorted
SEARCHSORTED_CPU_AVAILABLE = True
try:
from torchsearchsorted.cpu import searchsorted_cpu_wrapper
except ImportError:
SEARCHSORTED_CPU_AVAILABLE = False
# trying to import the CUDA searchsorted
SEARCHSORTED_GPU_AVAILABLE = True
try:
from torchsearchsorted.cuda import searchsorted_cuda_wrapper
except ImportError:
SEARCHSORTED_GPU_AVAILABLE = False
def searchsorted(a: torch.Tensor, v: torch.Tensor,
out: Optional[torch.LongTensor] = None,
side='left') -> torch.LongTensor:
assert len(a.shape) == 2, "input `a` must be 2-D."
assert len(v.shape) == 2, "input `v` must be 2-D."
assert (a.shape[0] == v.shape[0]
or a.shape[0] == 1
or v.shape[0] == 1), ("`a` and `v` must have the same number of "
"rows or one of them must have only one ")
assert a.device == v.device, '`a` and `v` must be on the same device'
result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
if out is not None:
assert out.device == a.device, "`out` must be on the same device as `a`"
assert out.dtype == torch.long, "out.dtype must be torch.long"
assert out.shape == result_shape, ("If the output tensor is provided, "
"its shape must be correct.")
else:
out = torch.empty(result_shape, device=v.device, dtype=torch.long)
if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
raise Exception('torchsearchsorted was called on a CUDA tensor, but the CUDA '
'extension is not available. Please install it.')
if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
raise Exception('torchsearchsorted on CPU is not available. '
'Please install it.')
left_side = 1 if side=='left' else 0
if a.is_cuda:
searchsorted_cuda_wrapper(a, v, out, left_side)
else:
searchsorted_cpu_wrapper(a, v, out, left_side)
return out
import numpy as np
def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
"""Numpy version of searchsorted that works batch-wise on pytorch tensors
"""
nrows_a = a.shape[0]
(nrows_v, ncols_v) = v.shape
nrows_out = max(nrows_a, nrows_v)
out = np.empty((nrows_out, ncols_v), dtype=np.long)
def sel(data, row):
return data[0] if data.shape[0] == 1 else data[row]
for row in range(nrows_out):
out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
return out
import pytest
import torch
devices = {'cpu': torch.device('cpu')}
if torch.cuda.is_available():
devices['cuda'] = torch.device('cuda:0')
@pytest.fixture(params=devices.values(), ids=devices.keys())
def device(request):
return request.param
import pytest
import torch
import numpy as np
from torchsearchsorted import searchsorted, numpy_searchsorted
from itertools import product, repeat
def test_searchsorted_output_dtype(device):
B = 100
A = 50
V = 12
a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
v = torch.rand(B, A, device=device)
out = searchsorted(a, v)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
out = torch.empty(v.shape, dtype=torch.long, device=device)
searchsorted(a, v, out)
assert out.dtype == torch.long
np.testing.assert_array_equal(out.cpu().numpy(), out_np)
Ba_val = [1, 100, 200]
Bv_val = [1, 100, 200]
A_val = [1, 50, 500]
V_val = [1, 12, 120]
side_val = ['left', 'right']
nrepeat = 100
@pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
def test_searchsorted_correct(Ba, Bv, A, V, side, device):
if Ba > 1 and Bv > 1 and Ba != Bv:
return
for test in range(nrepeat):
a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
v = torch.rand(Bv, V, device=device)
out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
side=side)
out = searchsorted(a, v, side=side).cpu().numpy()
np.testing.assert_array_equal(out, out_np)