songxinkai / nerf-pytorch / Commits / c3ccc0bd

Commit c3ccc0bd authored Apr 16, 2020 by Yen-Chen Lin

Add doc string

parent f61ca730

Showing 1 changed file with 156 additions and 57 deletions:
run_nerf.py (+156, -57)
...
@@ -25,6 +25,8 @@ DEBUG = False
def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches.
    """
    if chunk is None:
        return fn
    def ret(inputs):
...
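The body of `ret` is elided above; as a sketch of the pattern this helper implements (an assumption, not necessarily the file's exact code), the wrapper applies `fn` to `chunk`-sized slices and concatenates the partial outputs:

import torch

def batchify_sketch(fn, chunk):
    # Wrap 'fn' so it runs on at most 'chunk' rows at a time, then stitch
    # the partial outputs back together along the batch dimension.
    if chunk is None:
        return fn
    def ret(inputs):
        return torch.cat([fn(inputs[i:i+chunk])
                          for i in range(0, inputs.shape[0], chunk)], 0)
    return ret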
@@ -33,7 +35,8 @@ def batchify(fn, chunk):
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'.
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)
...
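To make the reshape concrete: `inputs` carries one 3D point per sample per ray, and flattening merges the ray and sample axes before the embedding is applied. A small shape check under assumed dimensions (1024 rays, 64 samples per ray):

import torch

inputs = torch.rand(1024, 64, 3)                             # [N_rays, N_samples, 3]
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])  # [N_rays*N_samples, 3]
assert inputs_flat.shape == (1024 * 64, 3)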
@@ -49,7 +52,8 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
...
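`render_rays` returns a dict per chunk; a plausible self-contained sketch of the full pattern (assumed from the visible lines, since the loop body is elided) accumulates each key into a list and concatenates at the end. `render_fn` here is a hypothetical stand-in for `render_rays` with its kwargs bound:

import torch

def batchify_rays_sketch(rays_flat, render_fn, chunk=1024*32):
    # Accumulate each returned key across chunks, then concatenate along rays.
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_fn(rays_flat[i:i+chunk])
        for k in ret:
            all_ret.setdefault(k, []).append(ret[k])
    return {k: torch.cat(v, 0) for k, v in all_ret.items()}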
@@ -66,7 +70,28 @@ def render(H, W, focal, chunk=1024*32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
        camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, focal, c2w)
...
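A hedged usage sketch based on the docstring above (names like `pose` and `batch_rays` are assumed to be in scope): passing `c2w` renders a full H x W image, while passing `rays` renders an arbitrary batch.

# Full-image rendering from one camera pose (pose: a [3, 4] or [4, 4] matrix):
rgb, disp, acc, extras = render(H, W, focal, chunk=1024*32, c2w=pose[:3, :4])
# Batched rendering from precomputed rays of shape [2, batch_size, 3]:
rgb, disp, acc, extras = render(H, W, focal, rays=batch_rays, use_viewdirs=True)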
@@ -151,6 +176,8 @@ def render_path(render_poses, hwf, chunk, render_kwargs, gt_imgs=None, savedir=N
def create_nerf(args):
    """Instantiate NeRF's MLP model.
    """
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
    input_ch_views = 0
...
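For orientation, with the standard NeRF positional encoding (identity plus a sin and cos per frequency, per coordinate) and the default `multires=10` from `config_parser` below, the embedded input width works out as shown; this arithmetic is an assumption about `get_embedder`, not read from the elided code:

multires = 10                    # default for 3D locations (see config_parser)
input_ch = 3 + 3 * 2 * multires  # identity + (sin, cos) per frequency = 63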
@@ -233,7 +260,17 @@ def create_nerf(args):
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)
...
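The `raw2alpha` lambda above is the first step of the standard NeRF compositing: alpha_i = 1 - exp(-sigma_i * delta_i), then weight w_i = alpha_i * prod over j<i of (1 - alpha_j). A self-contained sketch of how it plausibly feeds the returned maps (the elided code may differ in details, e.g. it also scales `dists` by the norm of `rays_d`):

import torch
import torch.nn.functional as F

raw = torch.rand(1024, 64, 4)                    # [num_rays, num_samples, 4]
z_vals = torch.linspace(2., 6., 64).expand(1024, 64)
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], -1)
alpha = 1. - torch.exp(-F.relu(raw[..., 3]) * dists)          # raw2alpha
weights = alpha * torch.cumprod(
    torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
rgb_map = torch.sum(weights[..., None] * torch.sigmoid(raw[..., :3]), -2)  # [num_rays, 3]
acc_map = torch.sum(weights, -1)                              # [num_rays]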
@@ -281,6 +318,36 @@ def render_rays(ray_batch,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: float. Std dev of noise added to regularize sigma output.
      verbose: bool. If True, print more debugging info.
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
...
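To illustrate the `lindisp` and `perturb` arguments documented above, here is a hedged sketch of the standard NeRF depth sampling (assumed behavior; the file's elided code may differ in details):

import torch

N_rays, N_samples, near, far = 1024, 64, 2.0, 6.0
t_vals = torch.linspace(0., 1., steps=N_samples)
# lindisp=False: sample linearly in depth; lindisp=True: linearly in inverse depth.
z_vals = near * (1. - t_vals) + far * t_vals
# z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)   # lindisp=True
z_vals = z_vals.expand(N_rays, N_samples)
# perturb=1.: jitter each sample uniformly within its bin (stratified sampling),
# so every gradient step sees slightly different depths along each ray.
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)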
@@ -355,74 +422,114 @@ def config_parser():
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')
    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--chunk", type=int, default=1024*32,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options: armchair / cube / greek / vase')
    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')
    ## llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')
    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')
    return parser
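Because `--config` is declared with `is_config_file=True`, configargparse also accepts any of the flags above from a plain-text config file of `key = value` lines. A hypothetical example (file name and values are illustrative only, not from this repository):

# configs/fern_example.txt (hypothetical)
#   expname = fern_test
#   datadir = ./data/llff/fern
#   dataset_type = llff
#   N_importance = 64
#   use_viewdirs = True
# invoked as: python run_nerf.py --config configs/fern_example.txt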
def train():

    parser = config_parser()
    args = parser.parse_args()

    # Load data
    if args.dataset_type == 'llff':
...
@@ -453,7 +560,6 @@ def train():
        far = 1.
        print('NEAR FAR', near, far)
    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
...
@@ -467,7 +573,6 @@ def train():
        else:
            images = images[...,:3]
    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape,
...
@@ -481,7 +586,6 @@ def train():
        near = hemi_R - 1.
        far = hemi_R + 1.
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return
...
@@ -494,7 +598,6 @@ def train():
    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
...
@@ -509,7 +612,6 @@ def train():
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start
...
@@ -631,9 +733,6 @@ def train():
            psnr0 = mse2psnr(img_loss0)
        loss.backward()
        # NOTE: same as tf till here - 04/03/2020
        optimizer.step()
        # NOTE: IMPORTANT!
...
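For reference, the `mse2psnr` used in the training step above follows the standard image-quality conversion PSNR = -10 * log10(MSE) for images in [0, 1]; a minimal sketch assuming the usual definitions, not copied from the elided code:

import torch

img2mse = lambda x, y: torch.mean((x - y) ** 2)
mse2psnr = lambda mse: -10. * torch.log10(mse)  # PSNR in dB for [0, 1] images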