Commit 223fe62d by Yen-Chen Lin

Fix intrinsics problem

parent a1e1d271
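This commit replaces the scalar focal length with a full 3x3 camera intrinsics matrix K throughout the training and rendering pipeline (run_nerf.py) and the ray helpers (run_nerf_helpers.py). It also adds a LINEMOD dataset loader that supplies its own K, builds a fallback K from the focal length for loaders that do not provide one, and removes the SummaryWriter import.

As a minimal sketch (not part of the commit) of the pinhole intrinsics this change threads through, where make_intrinsics is a hypothetical helper, fx and fy are focal lengths in pixels, and (cx, cy) is the principal point:

import numpy as np

def make_intrinsics(fx, fy, cx, cy):
    # K maps camera-space directions to pixel coordinates; get_rays applies its inverse to cast rays.
    return np.array([[fx, 0., cx],
                     [0., fy, cy],
                     [0., 0., 1.]])

# The fallback built in train() assumes square pixels and a centered principal point:
# K = make_intrinsics(focal, focal, 0.5*W, 0.5*H)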
@@ -7,7 +7,6 @@ import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm, trange
 import matplotlib.pyplot as plt
@@ -17,6 +16,7 @@ from run_nerf_helpers import *
 from load_llff import load_llff_data
 from load_deepvoxels import load_dv_data
 from load_blender import load_blender_data
+from load_LINEMOD import load_LINEMOD_data

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -66,7 +66,7 @@ def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
     return all_ret

-def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
+def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
            near=0., far=1.,
            use_viewdirs=False, c2w_staticcam=None,
            **kwargs):
@@ -94,7 +94,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
     """
     if c2w is not None:
         # special case to render full image
-        rays_o, rays_d = get_rays(H, W, focal, c2w)
+        rays_o, rays_d = get_rays(H, W, K, c2w)
     else:
         # use provided ray batch
         rays_o, rays_d = rays
@@ -104,14 +104,14 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
         viewdirs = rays_d
         if c2w_staticcam is not None:
             # special case to visualize effect of viewdirs
-            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
+            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
         viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
         viewdirs = torch.reshape(viewdirs, [-1,3]).float()

     sh = rays_d.shape # [..., 3]
     if ndc:
         # for forward facing scenes
-        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

     # Create ray batch
     rays_o = torch.reshape(rays_o, [-1,3]).float()
@@ -134,7 +134,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
     return ret_list + [ret_dict]

-def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
+def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

     H, W, focal = hwf
@@ -151,7 +151,7 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N
     for i, c2w in enumerate(tqdm(render_poses)):
         print(i, time.time() - t)
         t = time.time()
-        rgb, disp, acc, _ = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
+        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
         rgbs.append(rgb.cpu().numpy())
         disps.append(disp.cpu().numpy())
         if i==0:
@@ -537,7 +537,7 @@ def train():
     args = parser.parse_args()

     # Load data
+    K = None
     if args.dataset_type == 'llff':
         images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                   recenter=True, bd_factor=.75,
@@ -579,6 +579,17 @@ def train():
         else:
             images = images[...,:3]

+    elif args.dataset_type == 'LINEMOD':
+        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
+        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
+        print(f'[CHECK HERE] near: {near}, far: {far}.')
+        i_train, i_val, i_test = i_split
+
+        if args.white_bkgd:
+            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
+        else:
+            images = images[...,:3]
+
     elif args.dataset_type == 'deepvoxels':

         images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
@@ -601,6 +612,13 @@ def train():
     H, W = int(H), int(W)
     hwf = [H, W, focal]

+    if K is None:
+        K = np.array([
+            [focal, 0, 0.5*W],
+            [0, focal, 0.5*H],
+            [0, 0, 1]
+        ])
+
     if args.render_test:
         render_poses = np.array(poses[i_test])
@@ -647,7 +665,7 @@ def train():
             os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

-            rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
+            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
             print('Done rendering', testsavedir)
             imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
@@ -659,7 +677,7 @@ def train():
     if use_batching:
         # For random ray batching
         print('get rays')
-        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
+        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
         print('done, concats')
         rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
         rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
@@ -673,6 +691,7 @@ def train():
         i_batch = 0

     # Move training data to GPU
+    if use_batching:
         images = torch.Tensor(images).to(device)
     poses = torch.Tensor(poses).to(device)
     if use_batching:
@@ -710,10 +729,11 @@ def train():
             # Random from one image
             img_i = np.random.choice(i_train)
             target = images[img_i]
+            target = torch.Tensor(target).to(device)
             pose = poses[img_i, :3,:4]

             if N_rand is not None:
-                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
+                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                 if i < args.precrop_iters:
                     dH = int(H//2 * args.precrop_frac)
@@ -737,7 +757,7 @@ def train():
                 target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

         #####  Core optimization loop  #####
-        rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
+        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                         verbose=i < 10, retraw=True,
                                         **render_kwargs_train)
@@ -782,7 +802,7 @@ def train():
         if i%args.i_video==0 and i > 0:
             # Turn on testing mode
             with torch.no_grad():
-                rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
+                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
             print('Done, saving', rgbs.shape, disps.shape)
             moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
             imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
@@ -800,7 +820,7 @@ def train():
             os.makedirs(testsavedir, exist_ok=True)
             print('test poses shape', poses[i_test].shape)
             with torch.no_grad():
-                render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
+                render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
             print('Saved test set')
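The remaining hunks update the ray helpers (run_nerf_helpers.py). Note that ndc_rays is unchanged and now receives K[0][0], i.e. fx, so the NDC (forward-facing) path implicitly assumes square pixels. A small sketch, not part of the commit, of a guard a caller could add before relying on NDC with a loader-supplied K; focal_for_ndc is a hypothetical helper:

import numpy as np

def focal_for_ndc(K, rtol=1e-3):
    # ndc_rays takes a single focal length, so fx and fy should agree (approximately).
    fx, fy = K[0][0], K[1][1]
    if not np.isclose(fx, fy, rtol=rtol):
        raise ValueError(f"NDC rays assume square pixels, but fx={fx} != fy={fy}")
    return fx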
@@ -153,11 +153,11 @@ class NeRF(nn.Module):
 # Ray helpers
-def get_rays(H, W, focal, c2w):
+def get_rays(H, W, K, c2w):
     i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
     i = i.t()
     j = j.t()
-    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
+    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
     # Rotate ray directions from camera frame to the world frame
     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
     # Translate camera frame's origin to the world frame. It is the origin of all rays.
@@ -165,9 +165,9 @@ def get_rays(H, W, focal, c2w):
     return rays_o, rays_d

-def get_rays_np(H, W, focal, c2w):
+def get_rays_np(H, W, K, c2w):
     i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
-    dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
+    dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
     # Rotate ray directions from camera frame to the world frame
     rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
     # Translate camera frame's origin to the world frame. It is the origin of all rays.
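With these changes, get_rays and get_rays_np use the principal point (K[0][2], K[1][2]) and the per-axis focal lengths (K[0][0], K[1][1]) instead of assuming the image center and a single focal length. A quick sketch, not part of the commit, checking that the new formula reduces to the old one when K is the centered fallback built in train():

import numpy as np

H, W, focal = 4, 6, 2.0   # toy values for illustration
K = np.array([[focal, 0, 0.5*W],
              [0, focal, 0.5*H],
              [0, 0, 1]])

i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs_old = np.stack([(i - W*.5)/focal, -(j - H*.5)/focal, -np.ones_like(i)], -1)
dirs_new = np.stack([(i - K[0][2])/K[0][0], -(j - K[1][2])/K[1][1], -np.ones_like(i)], -1)
assert np.allclose(dirs_old, dirs_new)   # identical directions for a centered, square-pixel K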