Commit d00bbe81 by lvzhengyang

test model

parent f4bb5ddb
@@ -7,6 +7,8 @@ import pickle
import math
import pdb
from torch.utils.data import Dataset, DataLoader
class MLP(torch.nn.Module):
    def __init__(self, *sizes, batchnorm=False, dropout=False):
        super().__init__()
@@ -100,6 +102,25 @@ def load_cell_delay_data():
    f.close()
    return data
class myDataset(Dataset):
    def __init__(self, data, libcell, key):
        self.delays = data['delays'][libcell][key]
        self.topos = data['topos'][libcell][key]

    def __len__(self):
        return len(self.topos)

    def __getitem__(self, idx):
        # each sample: (fanin_topo, fanin_id, fanout_topo, fanout_id, delay)
        return self.topos[idx][0][0], self.topos[idx][0][1], \
               self.topos[idx][1][0], self.topos[idx][1][1], self.delays[idx]


def collate_fn(data):
    # keep the variable-size topologies as a plain list instead of stacking them into a tensor
    input_data = []
    labels = []
    for dd in data:
        input_data.append((dd[0], dd[1], dd[2], dd[3]))
        labels.append(dd[4])
    return input_data, labels
def preprocess(data):
    libcell = 'INVx1_ASAP7_75t_R'
    key = 'A-Y'
@@ -153,8 +174,38 @@ def preprocess(data):
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path)
def train(loader):
    model = CellDelayPred(7, 4, 32)
    model.cuda()
    model = torch.nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    for batch_idx, (topo, delay) in enumerate(loader):
        num_data = len(topo)
        batch_loss = torch.zeros(num_data, device='cuda')
        # topologies vary in size, so each sample is pushed through the model individually
        for di in range(num_data):
            fanin_topo = torch.from_numpy(topo[di][0]).cuda().float().unsqueeze_(0)
            fanin_id = topo[di][1]
            fanout_topo = torch.from_numpy(topo[di][2]).cuda().float().unsqueeze_(0)
            fanout_id = topo[di][3]
            pred = model(fanin_topo, fanout_topo, fanin_id, fanout_id)
            truth = torch.from_numpy(delay[di]).cuda().float().unsqueeze_(0)
            batch_loss[di] = F.mse_loss(pred, truth)
        # average the per-sample losses and take one optimizer step per batch
        loss = batch_loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
if __name__ == '__main__':
    data = load_cell_delay_data()

    # naive training
    preprocess(data)

    """
    # use DataLoader
    libcell = 'INVx1_ASAP7_75t_R'
    key = 'A-Y'
    dataset = myDataset(data, libcell, key)
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    train(loader)
    """
import torch
import numpy as np
import dgl
import torch.nn.functional as F
import random
import pdb
import time
import argparse
import os
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
from self_att import *
def load_model(model_path):
    model = CellDelayPred(7, 4, 32)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    return model
def test(data, model):
    libcell = 'INVx1_ASAP7_75t_R'
    key = 'A-Y'
    delays = np.stack(data['delays'][libcell][key])
    delays_log = np.log(delays)
    topos = data['topos'][libcell][key]
    fanin_topos, fanin_ids, fanout_topos, fanout_ids = [], [], [], []
    for topo in topos:
        fanin_topos.append(torch.tensor(topo[0][0]).float())
        fanin_ids.append(topo[0][1])
        fanout_topos.append(torch.tensor(topo[1][0]).float())
        fanout_ids.append(topo[1][1])
    num_data = len(topos)
    pred_delays_log = np.zeros((num_data, 4))

    model.eval()
    with torch.no_grad():
        for di in range(num_data):
            fanin = fanin_topos[di]
            fanin = fanin.cuda()
            fanin.unsqueeze_(0)
            fanout = fanout_topos[di]
            fanout = fanout.cuda()
            fanout.unsqueeze_(0)
            pred = model(fanin, fanout, fanin_ids[di], fanout_ids[di])
            pred_delays_log[di] = pred.cpu().numpy()

    # the model predicts log delays for the four timing corners; convert back before plotting
    corner_names = {0: 'ER', 1: 'EF', 2: 'LR', 3: 'LF'}
    pred_delays = np.exp(pred_delays_log)
    for corner in range(4):
        tt = delays[:, corner]
        pp = pred_delays[:, corner]
        minv = min(tt.min(), pp.min()) - 0.2
        maxv = max(tt.max(), pp.max()) + 0.2
        maxv = min(2000, maxv)
        # scatter plot of predicted vs. true delay with a y = x reference line
        plt.axis("square")
        plt.title(f'cell delay prediction ({corner_names[corner]}) of libcell {libcell}')
        plt.xlabel('Truth/ns')
        plt.ylabel('Predicted/ns')
        plt.xlim(minv - 10, maxv + 10)
        plt.ylim(minv - 10, maxv + 10)
        plt.axline((minv, minv), (maxv, maxv), color='r')
        # plt.axline((500, minv), (500, maxv), color='black', linestyle='-.')
        plt.scatter(tt, pp, s=10, c='b')
        save_dir = os.path.join('figures', libcell)
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f'{libcell}.{corner_names[corner]}.png'))
        plt.clf()

        # draw the prediction error distribution
        diff = pp - tt
        # drop large-error outliers before plotting the histogram
        diff = diff[np.abs(diff) <= 200]
        print(f"{corner_names[corner]} remove {pp.size - diff.size} / {pp.size} large diff points")
        plt.hist(diff, bins=100)
        plt.xlabel('Diff')
        plt.ylabel('Frequency')
        plt.xlim(-200, 200)
        plt.title(f'cell delay pred err distr ({corner_names[corner]}) of libcell {libcell}')
        plt.savefig(os.path.join(save_dir, f'{libcell}.{corner_names[corner]}.dist.png'))
        plt.clf()
if __name__ == '__main__':
    data = load_cell_delay_data()
    model_path = './weights/e-109-loss-1622.447.pt'
    model = load_model(model_path)
    test(data, model)