Commit 339b71e6 by songxinkai

analysis.py & nn_analysis.py

parent 62da0b21
@@ -7,3 +7,6 @@
 *.dae
 data/*
 logs/*
+.*swp
+ret-*
+.nfs*
#!/workspace/S/songxinkai/local/anaconda3/bin/python
import os
import sys
import json
import random
import numpy as np
from struct import unpack, pack
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
H = 400
W = 400
IMAGE = 6
C_SAM = 64
F_SAM = 128
RAY_BATCH = 32768
EMB_LEN = 90
data_dir = "dump_data"
def read_data(filename, skip_lines):
    with open(filename, 'r') as f:
        lines = f.readlines()
    data = [[float(x) for x in line.strip().split("[")[-1].split("]")[0].split(", ")] for line_id, line in enumerate(lines) if line_id >= skip_lines]
    return data
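# Assumed input format (inferred from the parsing above, not verified against the
# dump): each kept line carries one bracketed, comma-separated vector, e.g.
#   tensor([0.123, -0.456, 1.789])
# and the text between the last '[' and the first ']' is split on ", ".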
def quant(input, quant_step, q_min=-128, q_max=127):
    quant_step = float(quant_step)
    q_min = int(q_min)
    q_max = int(q_max)
    output = [round(x / quant_step) for x in input]
    for i in range(len(output)):
        output[i] = q_min if output[i] < q_min else output[i]
        output[i] = q_max if output[i] > q_max else output[i]
    return output
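# A minimal usage sketch (illustrative values, not taken from the dump):
#   quant([0.5, -1.2, 3.0], 0.0206)  ->  [24, -58, 127]
# (3.0 / 0.0206 ~ 145.6 rounds to 146, then clamps to q_max = 127)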
## Quantization
if False:
    i_min = 1.
    i_max = -1.
    quant_bit = 8
    inputs = []
    # for im_id in range(IMAGE):
    for ray_id in range(0, H*W, RAY_BATCH):
        print ("Reading %d"%ray_id)
        n_ray = min(RAY_BATCH, H*W - ray_id)
        input = read_data(os.path.join(data_dir, "0", "%d"%ray_id, "embedded.txt"), C_SAM*n_ray)
        inputs += input
    idx = random.sample(range(len(inputs)), H*W*10)
    for i in idx:
        t_min = min(inputs[i])
        t_max = max(inputs[i])
        i_min = t_min if t_min < i_min else i_min
        i_max = t_max if t_max > i_max else i_max
    abs_max = i_max if i_max + i_min > 0 else -i_min
    quant_step = 2 * abs_max / 2**quant_bit
    print (i_min, i_max, quant_step)  # -2.6404643058776855 2.193082332611084 0.02062862738966942
    # for im_id in range(IMAGE):
    with open(os.path.join(data_dir, "0", "embedded_quant.bin"), 'wb') as f:
        for i, input in enumerate(inputs):
            if i % (192 * 400) == 0:
                print ("Quantizing %d"%(i/(192*400)))
            i_q = quant(input, quant_step, -128, 127)
            f.write(pack("%dh"%(len(input)), *i_q))
## Delta & Stat
if False:
inputs = [[[[] for z in range(C_SAM+F_SAM)] for w in range(W)] for h in range(H)]
with open (os.path.join(data_dir, "0", "embedded_quant.bin"), 'rb') as f:
for h in range(H):
print ("Reading %d/%d"%(h, H))
for w in range(W):
for z in range(C_SAM + F_SAM):
inputs[h][w][z] += unpack("%dh"%(EMB_LEN), f.read(EMB_LEN * 2))
if sys.argv[1] == "h":
h_idx = sorted(random.sample(range(H-1), int(H / 2)))
h_stat = [0 for i in range(511)]
for h in h_idx:
w_idx = random.sample(range(W), int(W / 2))
print ("Delta H: %d/%d"%(h, H))
for w in w_idx:
for z in range(C_SAM + F_SAM):
for i in range(EMB_LEN):
delta = inputs[h+1][w][z][i] - inputs[h][w][z][i]
h_stat[delta] += 1
print ("h_stat:", h_stat)
with open (os.path.join(data_dir, "0", "embedded_quant_delta_h.bin"), 'wb') as f:
f.write(pack("%di"%(511), *h_stat))
elif sys.argv[1] == "w":
h_idx = sorted(random.sample(range(H), int(H / 2)))
w_stat = [0 for i in range(511)]
for h in h_idx:
w_idx = random.sample(range(W-1), int(W / 2))
print ("Delta W: %d/%d"%(h, H))
for w in w_idx:
for z in range(C_SAM + F_SAM):
for i in range(EMB_LEN):
delta = inputs[h][w+1][z][i] - inputs[h][w][z][i]
w_stat[delta] += 1
print ("w_stat:", w_stat)
with open(os.path.join(data_dir, "0", "embedded_quant_delta_w.bin"), 'wb') as f:
f.write(pack("%di"%(511), *w_stat))
elif sys.argv[1] == "z":
h_idx = sorted(random.sample(range(H), int(H / 2)))
z_stat = [0 for i in range(511)]
for h in h_idx:
print ("Delta Z: %d/%d"%(h, H))
w_idx = random.sample(range(W), int(W / 2))
for w in w_idx:
for z in range(C_SAM + F_SAM - 1):
for i in range(EMB_LEN):
delta = inputs[h][w][z+1][i] - inputs[h][w][z][i]
z_stat[delta] += 1
print ("z_stat:", z_stat)
with open(os.path.join(data_dir, "0", "embedded_quant_delta_z.bin"), 'wb') as f:
f.write(pack("%di"%(511), *z_stat))
if True:
    h_stat = []
    w_stat = []
    z_stat = []
    with open(os.path.join(data_dir, "0", "embedded_quant_delta_h.bin"), 'rb') as f:
        h_stat = unpack("%di"%(511), f.read(511*4))
    with open(os.path.join(data_dir, "0", "embedded_quant_delta_w.bin"), 'rb') as f:
        w_stat = unpack("%di"%(511), f.read(511*4))
    with open(os.path.join(data_dir, "0", "embedded_quant_delta_z.bin"), 'rb') as f:
        z_stat = unpack("%di"%(511), f.read(511*4))
    h_sum = sum(h_stat)
    w_sum = sum(w_stat)
    z_sum = sum(z_stat)
    res = {"h": {}, "w": {}, "z": {}, "all": {}}
    res['h'][0] = h_stat[255]/h_sum
    res['w'][0] = w_stat[255]/w_sum
    res['z'][0] = z_stat[255]/z_sum
    res['all'][0] = (h_stat[255]+w_stat[255]+z_stat[255])/(h_sum+w_sum+z_sum)
    for i in range(9):
        low = max(0, -2**i + 255)
        high = min(510, 2**i - 1 + 255)
        h_count = sum(h_stat[low:high+1])
        w_count = sum(w_stat[low:high+1])
        z_count = sum(z_stat[low:high+1])
        res['h'][2**i] = h_count/h_sum
        res['w'][2**i] = w_count/w_sum
        res['z'][2**i] = z_count/z_sum
        res['all'][2**i] = (h_count+w_count+z_count)/(h_sum+w_sum+z_sum)
    print (res)
    plt.figure(dpi=300, figsize=(16, 8))
    # plt.rcParams['font.sans-serif']=['SimHei']
    # plt.rcParams['axes.unicode_minus']=False
    idx = ["[0]"]  # + ["%d"%2**i for i in range(9)]
    for i in range(9):
        low = max(-255, -2**i)
        high = min(255, 2**i - 1)
        idx.append("[%d~%d]"%(low, high))
    h_count = [res['h'][0]] + [res['h'][2**i] for i in range(9)]
    w_count = [res['w'][0]] + [res['w'][2**i] for i in range(9)]
    z_count = [res['z'][0]] + [res['z'][2**i] for i in range(9)]
    all_count = [res['all'][0]] + [res['all'][2**i] for i in range(9)]
    x = np.arange(len(idx))
    width = 0.2
    plt.bar(x - width/2, h_count, label='H', alpha=0.6, width=width)
    plt.bar(x - 3*width/2, w_count, label='W', alpha=0.6, width=width)
    plt.bar(x + width/2, z_count, label='Z', alpha=0.6, width=width)
    plt.bar(x + 3*width/2, all_count, label='Average', alpha=0.6, width=width)
    plt.legend()
    plt.xlabel('Delta range')  # x axis: delta magnitude buckets
    plt.ylabel('Density')      # y axis: fraction of deltas falling in each bucket
    # plt.ylim(0, 1.0)
    # plt.yscale("log")
    # plt.xscale("log")
    print (idx)
    print (h_count)
    plt.xticks(x, idx)
    # for i in range(len(idx)):
    #     plt.text(i-0.4, count[i]+0.02, "%.3f"%count[i], va='center')
    # plt.title('Density')
    plt.tight_layout()
    plt.show()  # no-op under the Agg backend
    plt.savefig("h_delta.png")
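# The resulting chart plots, per bucket ("[0]", "[-1~0]", "[-2~1]", ..., "[-255~255]"),
# the fraction of neighbouring-sample deltas that land inside it, for the H, W and Z
# directions plus their pooled average.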
dump_data/3/98304/outputs_flat.txt 8388608
dump_data/3/98304/pts.txt 6291456
dump_data/3/98304/z_samples.txt 32768
dump_data/3/98304/viewdirs.txt 65536
dump_data/3/98304/embedded.txt 8388608
dump_data/3/65536/outputs_flat.txt 8388608
dump_data/3/65536/pts.txt 6291456
dump_data/3/65536/z_samples.txt 32768
dump_data/3/65536/viewdirs.txt 65536
dump_data/3/65536/embedded.txt 8388608
dump_data/3/32768/outputs_flat.txt 8388608
dump_data/3/32768/pts.txt 6291456
dump_data/3/32768/z_samples.txt 32768
dump_data/3/32768/viewdirs.txt 65536
dump_data/3/32768/embedded.txt 8388608
dump_data/3/131072/outputs_flat.txt 7405568
dump_data/3/131072/pts.txt 5554176
dump_data/3/131072/z_samples.txt 28928
dump_data/3/131072/viewdirs.txt 57856
dump_data/3/131072/embedded.txt 7405568
dump_data/3/0/outputs_flat.txt 8388608
dump_data/3/0/pts.txt 6291456
dump_data/3/0/z_samples.txt 32768
dump_data/3/0/viewdirs.txt 65536
dump_data/3/0/embedded.txt 8388608
dump_data/1/98304/outputs_flat.txt 8388608
dump_data/1/98304/pts.txt 6291456
dump_data/1/98304/z_samples.txt 32768
dump_data/1/98304/viewdirs.txt 65536
dump_data/1/98304/embedded.txt 8388608
dump_data/1/65536/outputs_flat.txt 8388608
dump_data/1/65536/pts.txt 6291456
dump_data/1/65536/z_samples.txt 32768
dump_data/1/65536/viewdirs.txt 65536
dump_data/1/65536/embedded.txt 8388608
dump_data/1/32768/outputs_flat.txt 8388608
dump_data/1/32768/pts.txt 6291456
dump_data/1/32768/z_samples.txt 32768
dump_data/1/32768/viewdirs.txt 65536
dump_data/1/32768/embedded.txt 8388608
dump_data/1/131072/outputs_flat.txt 7405568
dump_data/1/131072/pts.txt 5554176
dump_data/1/131072/z_samples.txt 28928
dump_data/1/131072/viewdirs.txt 57856
dump_data/1/131072/embedded.txt 7405568
dump_data/1/0/outputs_flat.txt 8388608
dump_data/1/0/pts.txt 6291456
dump_data/1/0/z_samples.txt 32768
dump_data/1/0/viewdirs.txt 65536
dump_data/1/0/embedded.txt 8388608
dump_data/5/98304/outputs_flat.txt 8388608
dump_data/5/98304/pts.txt 6291456
dump_data/5/98304/z_samples.txt 32768
dump_data/5/98304/viewdirs.txt 65536
dump_data/5/98304/embedded.txt 8388608
dump_data/5/65536/outputs_flat.txt 8388608
dump_data/5/65536/pts.txt 6291456
dump_data/5/65536/z_samples.txt 32768
dump_data/5/65536/viewdirs.txt 65536
dump_data/5/65536/embedded.txt 8388608
dump_data/5/32768/outputs_flat.txt 8388608
dump_data/5/32768/pts.txt 6291456
dump_data/5/32768/z_samples.txt 32768
dump_data/5/32768/viewdirs.txt 65536
dump_data/5/32768/embedded.txt 8388608
dump_data/5/131072/outputs_flat.txt 7405568
dump_data/5/131072/pts.txt 5554176
dump_data/5/131072/z_samples.txt 28928
dump_data/5/131072/viewdirs.txt 57856
dump_data/5/131072/embedded.txt 7405568
dump_data/5/0/outputs_flat.txt 8388608
dump_data/5/0/pts.txt 6291456
dump_data/5/0/z_samples.txt 32768
dump_data/5/0/viewdirs.txt 65536
dump_data/5/0/embedded.txt 8388608
dump_data/6/98304/outputs_flat.txt 2097152
dump_data/6/98304/pts.txt 6291456
dump_data/6/98304/z_samples.txt 32768
dump_data/6/98304/viewdirs.txt 65536
dump_data/6/98304/embedded.txt 8388608
dump_data/6/65536/outputs_flat.txt 8388608
dump_data/6/65536/pts.txt 6291456
dump_data/6/65536/z_samples.txt 32768
dump_data/6/65536/viewdirs.txt 65536
dump_data/6/65536/embedded.txt 8388608
dump_data/6/32768/outputs_flat.txt 8388608
dump_data/6/32768/pts.txt 6291456
dump_data/6/32768/z_samples.txt 32768
dump_data/6/32768/viewdirs.txt 65536
dump_data/6/32768/embedded.txt 8388608
dump_data/6/0/outputs_flat.txt 8388608
dump_data/6/0/pts.txt 6291456
dump_data/6/0/z_samples.txt 32768
dump_data/6/0/viewdirs.txt 65536
dump_data/6/0/embedded.txt 8388608
dump_data/2/98304/outputs_flat.txt 8388608
dump_data/2/98304/pts.txt 6291456
dump_data/2/98304/z_samples.txt 32768
dump_data/2/98304/viewdirs.txt 65536
dump_data/2/98304/embedded.txt 8388608
dump_data/2/65536/outputs_flat.txt 8388608
dump_data/2/65536/pts.txt 6291456
dump_data/2/65536/z_samples.txt 32768
dump_data/2/65536/viewdirs.txt 65536
dump_data/2/65536/embedded.txt 8388608
dump_data/2/32768/outputs_flat.txt 8388608
dump_data/2/32768/pts.txt 6291456
dump_data/2/32768/z_samples.txt 32768
dump_data/2/32768/viewdirs.txt 65536
dump_data/2/32768/embedded.txt 8388608
dump_data/2/131072/outputs_flat.txt 7405568
dump_data/2/131072/pts.txt 5554176
dump_data/2/131072/z_samples.txt 28928
dump_data/2/131072/viewdirs.txt 57856
dump_data/2/131072/embedded.txt 7405568
dump_data/2/0/outputs_flat.txt 8388608
dump_data/2/0/pts.txt 6291456
dump_data/2/0/z_samples.txt 32768
dump_data/2/0/viewdirs.txt 65536
dump_data/2/0/embedded.txt 8388608
dump_data/0/98304/outputs_flat.txt 8388608
dump_data/0/98304/pts.txt 6291456
dump_data/0/98304/z_samples.txt 32768
dump_data/0/98304/viewdirs.txt 65536
dump_data/0/98304/embedded.txt 8388608
dump_data/0/65536/outputs_flat.txt 8388608
dump_data/0/65536/pts.txt 6291456
dump_data/0/65536/z_samples.txt 32768
dump_data/0/65536/viewdirs.txt 65536
dump_data/0/65536/embedded.txt 8388608
dump_data/0/32768/outputs_flat.txt 8388608
dump_data/0/32768/pts.txt 6291456
dump_data/0/32768/z_samples.txt 32768
dump_data/0/32768/viewdirs.txt 65536
dump_data/0/32768/embedded.txt 8388608
dump_data/0/131072/outputs_flat.txt 7405568
dump_data/0/131072/pts.txt 5554176
dump_data/0/131072/z_samples.txt 28928
dump_data/0/131072/viewdirs.txt 57856
dump_data/0/131072/embedded.txt 7405568
dump_data/0/0/outputs_flat.txt 8388608
dump_data/0/0/pts.txt 6291456
dump_data/0/0/z_samples.txt 32768
dump_data/0/0/viewdirs.txt 65536
dump_data/0/0/embedded.txt 8388608
dump_data/4/98304/outputs_flat.txt 8388608
dump_data/4/98304/pts.txt 6291456
dump_data/4/98304/z_samples.txt 32768
dump_data/4/98304/viewdirs.txt 65536
dump_data/4/98304/embedded.txt 8388608
dump_data/4/65536/outputs_flat.txt 8388608
dump_data/4/65536/pts.txt 6291456
dump_data/4/65536/z_samples.txt 32768
dump_data/4/65536/viewdirs.txt 65536
dump_data/4/65536/embedded.txt 8388608
dump_data/4/32768/outputs_flat.txt 8388608
dump_data/4/32768/pts.txt 6291456
dump_data/4/32768/z_samples.txt 32768
dump_data/4/32768/viewdirs.txt 65536
dump_data/4/32768/embedded.txt 8388608
dump_data/4/131072/outputs_flat.txt 7405568
dump_data/4/131072/pts.txt 5554176
dump_data/4/131072/z_samples.txt 28928
dump_data/4/131072/viewdirs.txt 57856
dump_data/4/131072/embedded.txt 7405568
dump_data/4/0/outputs_flat.txt 8388608
dump_data/4/0/pts.txt 6291456
dump_data/4/0/z_samples.txt 32768
dump_data/4/0/viewdirs.txt 65536
dump_data/4/0/embedded.txt 8388608
#!/bin/bash
#- Job parameters
# (TODO)
# Please modify job name
#SBATCH -J test # The job name
#SBATCH -o ret-%j.out # Write the standard output to file named 'ret-<job_number>.out'
#SBATCH -e ret-%j.err # Write the standard error to file named 'ret-<job_number>.err'
#- Needed resources
# (TODO)
# Please modify your requirements
#SBATCH -p nv-gpu#,nv-gpu-hw # Submit to the 'nv-gpu' partition ('nv-gpu-hw' is commented out)
#SBATCH -t 0-8:00:00 # Run for a maximum time of 0 days, 8 hours, 00 mins, 00 secs
#SBATCH --nodes=1 # Request 1 node
#SBATCH --gres=gpu:4 # Request 4 GPUs per node
#SBATCH --gres-flags=enforce-binding # CPU-GPU Affinity
#SBATCH --constraint="Volta" # Request GPU Type: Volta(V100 or V100S) or RTX8000
###
### The system will alloc 8 cores per gpu by default.
### If you need more or less, use following:
### #SBATCH --cpus-per-task=K # Request K cores
###
#SBATCH --qos=gpu-short # Request QOS Type
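###
### Usage sketch (assumption, not part of the original script): submit with
###   sbatch <this_script>.sh
### stdout/stderr then land in ret-<jobid>.out / ret-<jobid>.err in the submit
### directory, per the -o/-e flags above.
###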
#- Operations
echo "Job start at $(date "+%Y-%m-%d %H:%M:%S")"
echo "Job run at:"
echo "$(hostnamectl)"
#- Load environments
source /tools/module_env.sh
module list # list modules loaded by default
##- tools
module load cluster-tools/v1.0
module load cmake/3.15.7
module load git/2.17.1
module load vim/8.1.2424
##- language
module load python3/3.6.8
##- cuda
module load cuda-cudnn/11.0-8.0.4
##- virtualenv
# source xxxxx/activate
#- Log information
echo $(module list) # list modules loaded
echo $(which gcc)
echo $(which python)
echo $(which python3)
cluster-quota # nas quota
nvidia-smi --format=csv --query-gpu=name,driver_version,power.limit # gpu info
echo "Use GPU ${CUDA_VISIBLE_DEVICES}$" # which gpus
#- Warning! Please do not change your CUDA_VISIBLE_DEVICES
#- in `.bashrc`, `env.sh`, or your job script
#- Job step
# sleep 28800
sleep 108000
#- End
echo "Job end at $(date "+%Y-%m-%d %H:%M:%S")"
/home/S/songxinkai/lustre/big_data/nerf/
\ No newline at end of file
@@ -35,7 +35,7 @@ def pose_spherical(theta, phi, radius):
 def load_blender_data(basedir, half_res=False, testskip=1):
-    splits = ['train', 'val', 'test']
+    splits = ['val', 'test', 'train']
     metas = {}
     for s in splits:
         with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
...
#!/workspace/S/songxinkai/local/anaconda3/bin/python
import time
import torch
import torch.nn as nn
import json
import torch.nn.functional as F
import numpy as np
class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        """
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)
    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)
        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)
            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)
            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)
        return outputs
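    # Shape walkthrough (assuming the __main__ config below: input_ch=63,
    # input_ch_views=27, W=256): x is [B, 90] -> input_pts [B, 63] and
    # input_views [B, 27]; after the skip concat at layer 4, h is [B, 256+63]
    # going into the next linear; with use_viewdirs=True the output is [B, 4]
    # (rgb + alpha).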
    def load_weights_from_keras(self, weights):
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
        # Load pts_linears
        for i in range(self.D):
            idx_pts_linears = 2 * i
            self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
            self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
        # Load feature_linear
        idx_feature_linear = 2 * self.D
        self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
        self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
        # Load views_linears
        idx_views_linears = 2 * self.D + 2
        self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
        self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
        # Load rgb_linear
        idx_rgb_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rgb_linear+1]))
        # Load alpha_linear
        idx_alpha_linear = 2 * self.D + 6
        self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
        self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
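    # Assumed layout of `weights` (inferred from the indices above): a flat list
    # alternating kernel/bias in the order pts_linears[0..D-1], feature_linear,
    # views_linears[0], rgb_linear, alpha_linear, with Keras [in, out] kernels
    # transposed to PyTorch's [out, in].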
if __name__ == "__main__":
    H = 400
    W = 400
    C_SAM = 64
    F_SAM = 128
    EMB_LEN = 90
    ckpt_path = "./logs/blender_paper_lego/200000.tar"
    quant_step = 0.02062862738966942
    device = torch.device("cuda")
    model = NeRF(D=8, W=256,
                 input_ch=63, output_ch=5, skips=[4],
                 input_ch_views=27, use_viewdirs=True).to(device)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['network_fine_state_dict'])
    # model = torch.nn.DataParallel(model.cuda(), device_ids=[0, 1, 2, 3])
    # inputs = [[[[] for z in range(C_SAM+F_SAM)] for w in range(W)] for h in range(H)]
    # with open(os.path.join(data_dir, "0", "embedded_quant.bin"), 'rb') as f:
    #     for h in range(H):
    #         print ("Reading %d/%d"%(h, H))
    #         for w in range(W):
    #             for z in range(C_SAM + F_SAM):
    #                 inputs[h][w][z] += torch.tensor(unpack("%dh"%(EMB_LEN), f.read(EMB_LEN * 2))) * quant_step
    batch_size = 4096
    N = 10000
    inputs = torch.rand((N*batch_size, 90)).to(device)
    torch.cuda.synchronize()
    start = time.time()
    print ("start inference")
    for i in range(N):
        output = model(inputs[i*batch_size:i*batch_size+batch_size])
        if i % 1000 == 0:
            print (i, output.shape)
    torch.cuda.synchronize()  # wait for queued GPU work before reading the clock
    print ("time:", time.time() - start)
 import torch
 # torch.autograd.set_detect_anomaly(True)
 import torch.nn as nn
+import json
 import torch.nn.functional as F
 import numpy as np
@@ -151,14 +152,29 @@ class NeRF(nn.Module):
 # Ray helpers
 def get_rays(H, W, K, c2w):
+    # print ("==== H, W, K, c2w =====")
+    # print (H, W, K, c2w)
     i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
+    # print ("==== i,j =====")
+    # print (i.shape)
+    # print (j.shape)
+    # print (i, j)
     i = i.t()
     j = j.t()
+    # print ("==== i.t(), j.t() =====")
+    # print (i.shape)
+    # print (j.shape)
+    # print (i, j)
     dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
+    # print ("========= dirs.shape ========")
+    # print (dirs.shape)
     # Rotate ray directions from camera frame to the world frame
     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
     # Translate camera frame's origin to the world frame. It is the origin of all rays.
     rays_o = c2w[:3,-1].expand(rays_d.shape)
+    # print ("==== rays_o, rays_d ====")
+    # print (rays_o.shape)
+    # print (rays_d.shape)
     return rays_o, rays_d
@@ -195,6 +211,10 @@ def ndc_rays(H, W, focal, near, rays_o, rays_d):
 # Hierarchical sampling (section 5.2)
 def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
     # Get pdf
+    # print ("=============== sample_pdf =====================")
+    # print ("bins", bins.shape, "weights", weights.shape, torch.sum(weights), N_samples, det, pytest)
+    # print (bins)
+    # print (weights)
     weights = weights + 1e-5  # prevent nans
     pdf = weights / torch.sum(weights, -1, keepdim=True)
     cdf = torch.cumsum(pdf, -1)
@@ -236,4 +256,8 @@ def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
     t = (u-cdf_g[...,0])/denom
     samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
+    # print (samples.shape)
+    # with open("tmp/samples.txt", 'a') as f:
+    #     json.dump(samples.tolist(), f)
+    #     f.write("\n")
     return samples