songxinkai / nerf-pytorch / Commits

Commit 223fe62d
authored Jun 08, 2021 by Yen-Chen Lin
Fix intrinsics problem

parent a1e1d271
Showing 2 changed files with 38 additions and 18 deletions (+38, -18):

    run_nerf.py          +34  -14
    run_nerf_helpers.py   +4   -4
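This commit threads a full 3x3 pinhole intrinsics matrix K through the ray-generation path in place of the single focal scalar, so a dataset that supplies its own calibration (the new LINEMOD loader below) can use a non-centered principal point. As a minimal sketch of the convention (the make_K helper is illustrative, not part of the repo):

    import numpy as np

    # Illustrative helper, not in the repo: standard pinhole intrinsics.
    # fx = K[0][0], fy = K[1][1] are focal lengths in pixels; the principal
    # point is (cx, cy) = (K[0][2], K[1][2]). The pre-commit code implicitly
    # assumed fx = fy = focal and (cx, cy) = (W/2, H/2).
    def make_K(focal, H, W, cx=None, cy=None):
        cx = 0.5 * W if cx is None else cx
        cy = 0.5 * H if cy is None else cy
        return np.array([[focal, 0.0,   cx],
                         [0.0,   focal, cy],
                         [0.0,   0.0,   1.0]])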
run_nerf.py
@@ -7,7 +7,6 @@ import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm, trange
 import matplotlib.pyplot as plt
@@ -17,6 +16,7 @@ from run_nerf_helpers import *
 from load_llff import load_llff_data
 from load_deepvoxels import load_dv_data
 from load_blender import load_blender_data
+from load_LINEMOD import load_LINEMOD_data

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -66,7 +66,7 @@ def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
     return all_ret


-def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
+def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
            near=0., far=1.,
            use_viewdirs=False, c2w_staticcam=None,
            **kwargs):
@@ -94,7 +94,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
     """
     if c2w is not None:
         # special case to render full image
-        rays_o, rays_d = get_rays(H, W, focal, c2w)
+        rays_o, rays_d = get_rays(H, W, K, c2w)
     else:
         # use provided ray batch
         rays_o, rays_d = rays
@@ -104,14 +104,14 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
         viewdirs = rays_d
         if c2w_staticcam is not None:
             # special case to visualize effect of viewdirs
-            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam)
+            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
         viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
         viewdirs = torch.reshape(viewdirs, [-1,3]).float()

     sh = rays_d.shape # [..., 3]
     if ndc:
         # for forward facing scenes
-        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

     # Create ray batch
     rays_o = torch.reshape(rays_o, [-1,3]).float()
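Note that ndc_rays keeps its scalar focal-length argument: the call above passes fx = K[0][0], presumably because the NDC warp in the helpers is derived for a symmetric frustum and takes no principal point at all.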
@@ -134,7 +134,7 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
     return ret_list + [ret_dict]


-def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
+def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

     H, W, focal = hwf
@@ -151,7 +151,7 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N
     for i, c2w in enumerate(tqdm(render_poses)):
         print(i, time.time() - t)
         t = time.time()
-        rgb, disp, acc, _ = render(H, W, focal, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
+        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
         rgbs.append(rgb.cpu().numpy())
         disps.append(disp.cpu().numpy())
         if i==0:
@@ -537,7 +537,7 @@ def train():
     args = parser.parse_args()

     # Load data
+    K = None
     if args.dataset_type == 'llff':
         images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                   recenter=True, bd_factor=.75,
@@ -579,6 +579,17 @@ def train():
         else:
             images = images[...,:3]

+    elif args.dataset_type == 'LINEMOD':
+        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
+        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
+        print(f'[CHECK HERE] near: {near}, far: {far}.')
+        i_train, i_val, i_test = i_split
+
+        if args.white_bkgd:
+            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
+        else:
+            images = images[...,:3]
+
     elif args.dataset_type == 'deepvoxels':

         images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
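The white_bkgd branch composites the loader's RGBA images onto a white background, matching the blender dataset path. A tiny numpy illustration with dummy values:

    import numpy as np

    # Dummy RGBA batch: dark-gray color, varying alpha per pixel.
    images = np.zeros((1, 2, 2, 4), dtype=np.float32)
    images[..., :3] = 0.2
    images[..., 3] = np.array([[1.0, 0.5], [0.0, 1.0]], dtype=np.float32)

    # out = rgb * alpha + (1 - alpha): pixels fade to white as alpha -> 0.
    out = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
    print(out[0, :, :, 0])  # [[0.2 0.6]
                            #  [1.  0.2]]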
@@ -601,6 +612,13 @@ def train():
     H, W = int(H), int(W)
     hwf = [H, W, focal]

+    if K is None:
+        K = np.array([
+            [focal, 0, 0.5*W],
+            [0, focal, 0.5*H],
+            [0, 0, 1]
+        ])
+
     if args.render_test:
         render_poses = np.array(poses[i_test])
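When a loader supplies no intrinsics (every dataset here except LINEMOD), the fallback K above reproduces the old focal-only camera exactly: fx = fy = focal and the principal point sits at the image center. A quick consistency check with illustrative numbers:

    import numpy as np

    H, W, focal = 400, 400, 555.0   # example values only
    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])

    i, j = 123.0, 45.0  # an arbitrary pixel
    old = ((i - W*.5)/focal, -(j - H*.5)/focal)            # pre-commit direction terms
    new = ((i - K[0][2])/K[0][0], -(j - K[1][2])/K[1][1])  # post-commit
    assert np.allclose(old, new)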
@@ -647,7 +665,7 @@ def train():
             os.makedirs(testsavedir, exist_ok=True)
             print('test poses shape', render_poses.shape)

-            rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
+            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
             print('Done rendering', testsavedir)
             imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
@@ -659,7 +677,7 @@ def train():
     if use_batching:
         # For random ray batching
         print('get rays')
-        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
+        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
         print('done, concats')
         rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
         rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
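The shape bookkeeping in this batching block is easier to follow with toy sizes; a sketch (N, H, W are stand-ins, and get_rays_np returns a [2, H, W, 3] stack of origins and directions):

    import numpy as np

    N, H, W = 5, 4, 6                   # toy sizes
    rays = np.zeros((N, 2, H, W, 3))    # [N, ro+rd, H, W, 3], as stacked above
    images = np.zeros((N, H, W, 3))

    rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])     # [N, H, W, ro+rd+rgb, 3]
    print(rays_rgb.shape)                                  # (5, 4, 6, 3, 3)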
@@ -673,6 +691,7 @@ def train():
         i_batch = 0

     # Move training data to GPU
-    images = torch.Tensor(images).to(device)
+    if use_batching:
+        images = torch.Tensor(images).to(device)
     poses = torch.Tensor(poses).to(device)
     if use_batching:
@@ -710,10 +729,11 @@ def train():
             # Random from one image
             img_i = np.random.choice(i_train)
             target = images[img_i]
+            target = torch.Tensor(target).to(device)
             pose = poses[img_i, :3,:4]

             if N_rand is not None:
-                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
+                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                 if i < args.precrop_iters:
                     dH = int(H//2 * args.precrop_frac)
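For the first precrop_iters iterations, rays are drawn only from a central crop whose half-extents are dH and the analogous dW. A hedged sketch of the resulting coordinate grid; the meshgrid pattern follows the surrounding code but is reconstructed here, not part of this diff:

    import torch

    H, W = 400, 400
    precrop_frac = 0.5                # illustrative value
    dH = int(H//2 * precrop_frac)     # 100: rows H//2-dH .. H//2+dH-1
    dW = int(W//2 * precrop_frac)
    coords = torch.stack(torch.meshgrid(
        torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
        torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)), -1)
    print(coords.shape)               # torch.Size([200, 200, 2])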
@@ -737,7 +757,7 @@ def train():
                 target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

         #####  Core optimization loop  #####
-        rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
+        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                         verbose=i < 10, retraw=True,
                                         **render_kwargs_train)
@@ -782,7 +802,7 @@ def train():
         if i%args.i_video==0 and i > 0:
             # Turn on testing mode
             with torch.no_grad():
-                rgbs, disps = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
+                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
             print('Done, saving', rgbs.shape, disps.shape)
             moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
             imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
@@ -800,7 +820,7 @@ def train():
             os.makedirs(testsavedir, exist_ok=True)
             print('test poses shape', poses[i_test].shape)
             with torch.no_grad():
-                render_path(torch.Tensor(poses[i_test]).to(device), hwf, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
+                render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
             print('Saved test set')
run_nerf_helpers.py
@@ -153,11 +153,11 @@ class NeRF(nn.Module):
 # Ray helpers
-def get_rays(H, W, focal, c2w):
+def get_rays(H, W, K, c2w):
     i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
     i = i.t()
     j = j.t()
-    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
+    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
     # Rotate ray directions from camera frame to the world frame
     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
     # Translate camera frame's origin to the world frame. It is the origin of all rays.
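Two things are packed into these few lines. The dirs expression is the standard pinhole back-projection, (i - cx)/fx and -(j - cy)/fy with z = -1 for the OpenGL-style camera axes; the broadcast multiply-and-sum that follows is a per-pixel matrix-vector product with the camera-to-world rotation, as the inline comment says. A small check of that second identity, with illustrative shapes:

    import torch

    torch.manual_seed(0)
    H, W = 3, 4
    dirs = torch.randn(H, W, 3)   # stand-in camera-frame directions
    c2w = torch.randn(3, 4)       # stand-in camera-to-world matrix

    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)
    assert torch.allclose(rays_d, dirs @ c2w[:3, :3].T)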
@@ -165,9 +165,9 @@ def get_rays(H, W, focal, c2w):
     return rays_o, rays_d


-def get_rays_np(H, W, focal, c2w):
+def get_rays_np(H, W, K, c2w):
     i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
-    dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
+    dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
     # Rotate ray directions from camera frame to the world frame
     rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
     # Translate camera frame's origin to the world frame. It is the origin of all rays.
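Putting it together, a usage sketch of the new numpy interface (toy values; assumes the repo's run_nerf_helpers.py is importable):

    import numpy as np
    from run_nerf_helpers import get_rays_np  # new signature: (H, W, K, c2w)

    H, W, focal = 4, 6, 10.0
    K = np.array([[focal, 0, 0.5*W],
                  [0, focal, 0.5*H],
                  [0, 0, 1]])
    c2w = np.eye(4)[:3, :4]   # identity pose: camera at the world origin

    rays_o, rays_d = get_rays_np(H, W, K, c2w)
    print(rays_o.shape, rays_d.shape)  # (4, 6, 3) (4, 6, 3)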