Commit 1bbeaf49 by lvzhengyang

Implement the training logic

parent f81e4bdf
......@@ -2,3 +2,5 @@ place_parser
data/asap7
*/__pycache__
*.log
build_model/runs
build_model/weights
......@@ -153,6 +153,7 @@ def check_legal(data_dir, parsed_libs_dir, parsed_sdf_dir, save_dir, block_name)
pickle.dump(illegal_leaves, f)
f.close()
# DEPRECATED!
def create_graph5(data_dir, parsed_libs_dir, parsed_sdf_dir, save_dir, block_name):
node_names = np.load(os.path.join(data_dir, 'node_names.npy'))
node_x = np.load(os.path.join(data_dir, 'node_x.npy'))
......@@ -1123,6 +1124,7 @@ def create_graph6(data_dir, parsed_libs_dir, parsed_sdf_dir, save_dir, block_nam
g.ndata['n_slews'] = ndata['n_slews'].float()
g.ndata['nf'] = ndata['nf'].float()
g.ndata['n_is_timing_endpt'] = ndata['n_is_timing_endpt'].float()
g.ndata['n_net_delays'] = ndata['n_net_delays'].float()
g.edges['cell_out'].data['e_delay'] = torch.tensor(edges_delay['cell_out']).float()
g.edges['cell_out'].data['ef'] = torch.tensor(edges_features['cell_out']).float()
......
......@@ -8,6 +8,8 @@ import time
import argparse
import os
from sklearn.metrics import r2_score
from torch.utils.tensorboard import SummaryWriter
# import tee
from model import PredModel
......@@ -49,6 +51,8 @@ def load_data():
for block in blocks:
graph_path = os.path.join(dir_prefix, block, "parsed", f"{block}.graph.bin")
g = dgl.load_graphs(graph_path)[0][0].to('cuda')
g.ndata['n_net_delays_log'] = torch.log(0.0001 + g.ndata['n_net_delays']) + 7.6
g.edges['cell_out'].data['e_cell_delays_log'] = torch.log(0.0001 + g.edges['cell_out'].data['e_delay'])
topo, topo_time = gen_topo(g)
ts = {
'input_nodes': (g.ndata['nf'][:, 1] < 0.5).nonzero().flatten().type(torch.int64),
......@@ -60,31 +64,58 @@ def load_data():
'endpoints': (g.ndata['n_is_timing_endpt'] > 0.5).nonzero().flatten().type(torch.long),
'topo': topo,
}
mask = torch.zeros(g.nodes().size(0), dtype=torch.bool, device='cuda')
mask[ts['input_nodes']] = 1
ts['mask'] = mask
g.ndata['n_net_delays_log'][mask] = 0
data[block] = g, ts
return data
def train(model, data_train):
    """Train ``model`` on the graphs in ``data_train``, logging to TensorBoard.

    Every epoch draws a random batch of graphs, accumulates the gradients of
    the summed net-delay and cell-delay MSE losses over that batch, and then
    takes a single Adam step. Per-epoch loss sums (not means) are written to
    TensorBoard, printed every 10 epochs, and the model weights are saved
    under ``weights/`` every 100 epochs.

    Args:
        model: Object called as ``model(g, ts)`` that returns a
            ``(pred_net_delays, pred_cell_delays, pred_atslew)`` triple and
            provides ``parameters()``, ``train()`` and ``state_dict()``.
        data_train: Mapping from block name to a ``(g, ts)`` pair. ``g`` must
            expose ``g.ndata['n_net_delays_log']`` and
            ``g.edges['cell_out'].data['e_cell_delays_log']``; ``ts`` must
            contain ``ts['mask']``.

    Raises:
        ValueError: If ``data_train`` is empty.
        RuntimeError: If the pre-flight forward pass fails on any block
            (chained to the original exception).
    """
    # random.sample() needs a real sequence; dict views raise TypeError on
    # Python 3.11+.
    samples = list(data_train.items())
    if not samples:
        raise ValueError('data_train is empty; nothing to train on')

    writer = SummaryWriter()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

    try:
        # Pre-flight check: push every graph through the model once so that a
        # broken graph is reported by name before the long training loop.
        for k, (g, ts) in samples:
            print(f'-------- {k} --------')
            try:
                model(g, ts)
            except Exception as err:
                print('error')
                raise RuntimeError(
                    f'model forward pass failed on block {k!r}'
                ) from err

        # Never ask for more graphs than are available.
        batch_size = min(5, len(samples))
        os.makedirs('weights', exist_ok=True)

        for e in range(100000):
            model.train()
            epoch_loss_net_delays, epoch_loss_cell_delays = 0.0, 0.0
            optimizer.zero_grad()

            for _, (g, ts) in random.sample(samples, batch_size):
                # Only the delays are trained on; the third output is unused.
                pred_net_delays, pred_cell_delays, _ = model(g, ts)

                # Net delays are defined on pins. Zero the prediction at the
                # masked nodes so they are scored against the masked target
                # rather than a raw model output.
                pred_net_delays[ts['mask']] = 0
                loss_net_delays = F.mse_loss(
                    pred_net_delays, g.ndata['n_net_delays_log'])
                loss_cell_delays = F.mse_loss(
                    pred_cell_delays,
                    g.edges['cell_out'].data['e_cell_delays_log'])

                epoch_loss_net_delays += loss_net_delays.item()
                epoch_loss_cell_delays += loss_cell_delays.item()

                # Gradients accumulate across the batch; one step per epoch.
                (loss_net_delays + loss_cell_delays).backward()

            optimizer.step()

            writer.add_scalar(
                'net_delays_loss/train: ', epoch_loss_net_delays, e)
            writer.add_scalar(
                'cell_delays_loss/train: ', epoch_loss_cell_delays, e)

            # log
            if (e + 1) % 10 == 0:
                print('epoch: {}, net_delay_loss (train): {:.6f}, cell_delay_loss (train): {:.6f}'.format(
                    e + 1, epoch_loss_net_delays, epoch_loss_cell_delays
                ))

            # save model
            if (e + 1) % 100 == 0:
                print('-------- Save Model --------')
                save_path = os.path.join('weights', f'{e}.pt')
                torch.save(model.state_dict(), save_path)
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()
if __name__ == "__main__":
data = load_data()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment