Commit 7158181e by Yen-Chen Lin

Better training interface

parent c3ccc0bd
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm, trange
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -673,7 +673,7 @@ def train(): ...@@ -673,7 +673,7 @@ def train():
rays_rgb = torch.Tensor(rays_rgb).to(device) rays_rgb = torch.Tensor(rays_rgb).to(device)
N_iters = 1000000 N_iters = 200000 + 1
print('Begin') print('Begin')
print('TRAIN views are', i_train) print('TRAIN views are', i_train)
print('TEST views are', i_test) print('TEST views are', i_test)
...@@ -682,7 +682,7 @@ def train(): ...@@ -682,7 +682,7 @@ def train():
# Summary writers # Summary writers
# writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
for i in range(start, N_iters): for i in trange(start, N_iters):
time0 = time.time() time0 = time.time()
# Sample random ray batch # Sample random ray batch
...@@ -745,7 +745,7 @@ def train(): ...@@ -745,7 +745,7 @@ def train():
################################ ################################
dt = time.time()-time0 dt = time.time()-time0
print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end ##### ##### end #####
# Rest is logging # Rest is logging
...@@ -784,11 +784,13 @@ def train(): ...@@ -784,11 +784,13 @@ def train():
print('Saved test set') print('Saved test set')
"""
if i%args.i_print==0 or i < 10: if i%args.i_print==0 or i < 10:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy()) print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt)) print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print): with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss) tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr) tf.contrib.summary.scalar('psnr', psnr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment