Commit a050c444 by lvzhengyang

train super-edge

parent d00bbe81
@@ -2,7 +2,5 @@ place_parser
 data/asap7
 */__pycache__
 *.log
-build_model/runs
-build_model/weights
-cell_delay/runs
-cell_delay/weights
+*/runs
+*/weights
@@ -1188,6 +1188,9 @@ if __name__ == '__main__':
     save_dir = os.path.join(techlib_dir, "parsed_lib")
     parse_libs(libs, save_dir)
+    # debug
+    blocks = "gcd".split()
     for block in blocks:
         print(f'-------- {block} --------')
         block_dir = os.path.join(raw_data_dir, tag, block)
...
TODO get pin layer
DONE rebuild pin_cap, currently some pin_cap is 0
DONE get gate driving strengths
import numpy as np
import pickle
import json
import os
from utils import *
import torch
from torch.utils.data import Dataset, DataLoader
import pdb
def load_topos(data_dir, block):
with open(os.path.join(data_dir, f'{block}.topo.dict'), 'rb') as f:
topos = pickle.load(f)
return topos
def build_dataset(topos, dataset):
# dataset is classified by num_fanout
for topo in topos:
# NOTE: 'fanin' holds the net's sink pins here, so its length is the fanout degree
num_fanout = len(topo['fanin'])
if num_fanout not in dataset:
dataset[num_fanout] = {
'pin_features': [],
'transition': [],
'transition_incr': [],
'delay': []
}
# make all the pins as a matrix
# each row is a pin, with the last row as the fanout pin
# pin_features {name: #dim (description)}
# dir: 1 (0 for input, 1 for output)
# caps: 4
# loc: 2
# layer: (TO BE IMPLEMENTED)
# num_fanout: 1
# libcell_drive_strength: 1
# * tot_dim = 1 + 4 + 2 + 1 + 1 = 9
dtype = np.float32
pin_features = np.zeros([num_fanout + 1, 9], dtype=dtype)
def get_pin_feat(pin_dict):
pin_feat = np.zeros(9)
if pin_dict['dir'] == 'out':
pin_feat[0] = 1
else:
pin_feat[0] = 0
pin_feat[1:5] = pin_dict['caps']
pin_feat[5:7] = pin_dict['loc']
pin_feat[7] = pin_dict['num_fanout']
pin_feat[8] = pin_dict['libcell_drive_strength']
return pin_feat
for i in range(num_fanout):
pin_dict = topo['fanin'][i]
pin_features[i] = get_pin_feat(pin_dict)
pin_features[-1] = get_pin_feat(topo['fanout'][0])
transition = np.zeros([num_fanout + 1, 4], dtype=dtype)
transition_incr = np.zeros([num_fanout + 1, 4], dtype=dtype)
delay = np.zeros([num_fanout + 1, 4], dtype=dtype)
for i in range(num_fanout):
pin_dict = topo['fanin'][i]
transition[i] = pin_dict['transition']
transition_incr[i] = pin_dict['transition_incr']
delay[i] = pin_dict['delay_from_fanout']
transition[-1] = topo['fanout'][0]['transition']
# transition_incr and delay of the fanout (driver) pin are left at 0;
# they must be masked (the last row is dropped) for training/testing
dataset[num_fanout]['pin_features'].append(pin_features)
dataset[num_fanout]['transition'].append(transition)
dataset[num_fanout]['transition_incr'].append(transition_incr)
dataset[num_fanout]['delay'].append(delay)
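# Minimal usage sketch of build_dataset (hypothetical toy values, not from the
# real parser): one net with two sink pins and one driver shows the per-degree
# grouping and the [num_fanout + 1, 9] pin matrix with the driver in the last row.
# toy_pin = {'dir': 'in', 'caps': [0.1, 0.1, 0.2, 0.2], 'loc': [1.0, 2.0],
#            'num_fanout': 2, 'libcell_drive_strength': 2,
#            'transition': np.ones(4), 'transition_incr': np.ones(4),
#            'delay_from_fanout': np.ones(4)}
# toy_topo = {'fanin': [dict(toy_pin), dict(toy_pin)],
#             'fanout': [dict(toy_pin, dir='out')]}
# dataset = {}
# build_dataset([toy_topo], dataset)
# dataset[2]['pin_features'][0].shape  # (3, 9): two sinks + the driver row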
class myDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data['pin_features'])
def __getitem__(self, idx):
return self.data['pin_features'][idx], \
self.data['transition'][idx], \
self.data['transition_incr'][idx], \
self.data['delay'][idx]
def build_loaders():
pdk = "asap7"
tag = "no_timing_opt"
raw_data_dir = f"../data/{pdk}"
techlib_dir = os.path.join(raw_data_dir, "techlib")
lib_dir = os.path.join(techlib_dir, 'parsed_lib')
blocks = "aes aes-mbff gcd ibex jpeg uart".split()
test_blocks = {'aes', 'uart'}
train_blocks = set(blocks) - test_blocks
rebuild_dataset = False # flip to True to rebuild from the parsed topos
dataset_dir = 'dataset'
if rebuild_dataset:
train_dataset = dict()
for block in train_blocks:
print(f'-------- {block} --------')
block_dir = os.path.join(raw_data_dir, tag, block)
data_dir = os.path.join(block_dir, 'parsed')
topos = load_topos(data_dir, block)
build_dataset(topos, train_dataset)
test_dataset = dict()
for block in test_blocks:
print(f'-------- {block} --------')
block_dir = os.path.join(raw_data_dir, tag, block)
data_dir = os.path.join(block_dir, 'parsed')
topos = load_topos(data_dir, block)
build_dataset(topos, test_dataset)
os.makedirs(dataset_dir, exist_ok=True)
with open(os.path.join(dataset_dir, 'train.pkl'), 'wb') as f:
pickle.dump(train_dataset, f)
with open(os.path.join(dataset_dir, 'test.pkl'), 'wb') as f:
pickle.dump(test_dataset, f)
else:
with open(os.path.join(dataset_dir, 'train.pkl'), 'rb') as f:
train_dataset = pickle.load(f)
with open(os.path.join(dataset_dir, 'test.pkl'), 'rb') as f:
test_dataset = pickle.load(f)
# build loaders
batch_size = 1024
max_fanout = 32
train_loaders = dict()
for key in train_dataset:
if key > max_fanout:
continue
train_loaders[key] = DataLoader(myDataset(train_dataset[key]), batch_size=batch_size, num_workers=4, shuffle=True)
test_loaders = dict()
for key in test_dataset:
if key > max_fanout:
continue
test_loaders[key] = DataLoader(myDataset(test_dataset[key]), batch_size=batch_size, num_workers=4)
return train_loaders, test_loaders
if __name__ == '__main__':
train_loaders, test_loaders = build_loaders()
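# Quick shape check (a sketch, assuming the dataset pickles already exist):
if __name__ == '__main__':
k = next(iter(train_loaders)) # one fanout-degree bucket
pin_feat, trans, trans_incr, delay = next(iter(train_loaders[k]))
print(pin_feat.shape) # [batch, k + 1, 9]; the driver pin is the last row
print(delay.shape) # [batch, k + 1, 4]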
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pdb
class MLP(torch.nn.Module):
def __init__(self, *sizes, batchnorm=False, dropout=False):
super().__init__()
fcs = []
for i in range(1, len(sizes)):
fcs.append(torch.nn.Linear(sizes[i - 1], sizes[i]))
if i < len(sizes) - 1:
fcs.append(torch.nn.LeakyReLU(negative_slope=0.2))
if dropout: fcs.append(torch.nn.Dropout(p=0.2))
if batchnorm: fcs.append(torch.nn.BatchNorm1d(sizes[i]))
self.layers = torch.nn.Sequential(*fcs)
def forward(self, x):
return self.layers(x)
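# Quick shape check for the MLP helper (illustrative sizes): MLP(9, 64, 4)
# builds Linear(9, 64) -> LeakyReLU(0.2) -> Linear(64, 4); nn.Linear broadcasts
# over leading dimensions, so 3D [batch, pins, features] inputs work as-is:
# mlp = MLP(9, 64, 4)
# mlp(torch.randn(8, 6, 9)).shape # torch.Size([8, 6, 4])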
class SelfAtt(nn.Module):
def __init__(self, input_size, n_heads, hidden_size_per_head):
super().__init__()
self.n_heads = n_heads
self.input_size = input_size
self.hidden_size_per_head = hidden_size_per_head
self.query = MLP(input_size, n_heads * hidden_size_per_head)
self.key = MLP(input_size, n_heads * hidden_size_per_head)
self.value = MLP(input_size, n_heads * hidden_size_per_head)
self.reduce_heads = MLP(n_heads * hidden_size_per_head, hidden_size_per_head)
# @param x: [#batch, #num_inputs, #n_heads*hidden_size_per_head]
def _transpose(self, x):
x = x.view(x.shape[0], x.shape[1], self.n_heads, self.hidden_size_per_head)
return x.permute(0, 2, 1, 3)
# @param input: [#batch, #num_inputs, #features]
def forward(self, input):
query = self.query(input)
key = self.key(input)
value = self.value(input)
query = self._transpose(query)
key = self._transpose(key)
value = self._transpose(value)
att_scores = torch.matmul(query, key.transpose(-1, -2))
att_scores = att_scores / math.sqrt(self.hidden_size_per_head)
att_probs = nn.Softmax(dim=-1)(att_scores)
context = torch.matmul(att_probs, value)
context = context.permute(0, 2, 1, 3).contiguous()
context = context.view(context.shape[0], context.shape[1], self.n_heads * self.hidden_size_per_head)
output = self.reduce_heads(context)
return output
class NetPred(nn.Module):
def __init__(self, input_size, output_size, n_heads, hidden_size_per_head):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.n_heads = n_heads
self.hidden_size_per_head = hidden_size_per_head
self.att0 = SelfAtt(input_size, n_heads, hidden_size_per_head)
self.att1 = SelfAtt(hidden_size_per_head, n_heads, hidden_size_per_head)
self.att2 = SelfAtt(hidden_size_per_head, n_heads, hidden_size_per_head)
self.delay_pred_mlp = MLP(output_size + hidden_size_per_head, 64, output_size)
self.transition_pred_mlp = MLP(output_size + hidden_size_per_head, 64, output_size)
def forward(self, pin_feat, fanout_transition):
x = self.att0(pin_feat)
x = self.att1(x)
x = self.att2(x)
x = x[:, :-1, :] # drop the fanout (driver) pin row: predictions are per sink pin
num_fanout = x.size(1)
trans = fanout_transition.unsqueeze(1).repeat(1, num_fanout, 1)
x = torch.cat([x, trans], dim=-1)
delay_pred = self.delay_pred_mlp(x)
transition_pred = self.transition_pred_mlp(x)
return delay_pred, transition_pred
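# Shape walk-through for NetPred (a sketch with hypothetical sizes): a batch of
# 8 nets, each with 5 sink pins plus the driver pin in the last row.
if __name__ == '__main__':
model = NetPred(input_size=9, output_size=4, n_heads=8, hidden_size_per_head=64)
pin_feat = torch.randn(8, 6, 9) # [batch, num_pins, features]
fanout_transition = torch.randn(8, 4) # driver-pin transition, [batch, 4]
delay_pred, transition_pred = model(pin_feat, fanout_transition)
print(delay_pred.shape, transition_pred.shape) # torch.Size([8, 5, 4]) twice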
"""
@brief extract all super_edge features
@author lvzhengyang
@date 2024.01.29
"""
""" Intro
extract the net features in json format.
each net contains a fanout and (multiple) fanin(s)
[
{
'net': #net_name,
'fanout': [
{
'pin_full_name': #pin_full_name,
'inst': #inst_name,
'libcell': #libcell_name,
'libcell_pin': #libcell_pin_name,
'dir': #dir,
'caps': #caps,
'loc': #(x, y),
'layer': #layer (maybe one-hot encoding?),
'libcell_drive_strength': #libcell_drive_strength,
'num_fanout': #num_fanout,
'transition': #transition,
'arrival': #arrival_time,
'required_arrival': #required_arrival_time,
}
]
'fanin': [
{
'pin_full_name': #pin_full_name,
'inst': #inst_name,
'libcell': #libcell_name,
'libcell_pin': #libcell_pin_name,
'dir': #dir,
'caps': #caps,
'loc': #(x, y),
'layer': #layer (maybe one-hot encoding?),
'libcell_drive_strength': #libcell_drive_strength,
'num_fanout': #num_fanout,
'transition': #transition,
'arrival': #arrival_time,
'required_arrival': #required_arrival_time,
'delay_from_fanout': #delay,
'transition_incr': #incr,
},
{},
...
]
'layout': {
'RUDY': ,
'cell_density': ,
'macro_region': ,
...
},
'steiner_tree': {
'branches': [
'src': (x, y),
'dst': (x, y),
]
},
},
{
'net': #net_name,
...
},
...
]
"""
""" Suggestions
1. gather nets of the same degree into a group, so that they can easily be
batched, making training/inference more efficient (see the sketch below;
build_dataset in gather_data.py implements this).
"""
""" Work-around flow
1. extract net topo and layout features from DREAMPlace
2. extract transition and delay from .sdf file
3. build a steiner tree generator
"""
import numpy as np
import pickle
import json
import os
from utils import *
import pdb
def parse_libs2(libs, save_dir):
# NOTE: same as utils.parse_libs, except that a missing rise/fall_capacitance
# falls back to the pin's max_capacitance instead of 0.0
libdata = dict()
pin_cap = dict()
for tag in libs:
liberty_files = libs[tag]
libdata[tag] = dict()
pin_cap[tag] = dict()
for liberty_file in liberty_files:
if liberty_file.endswith('.gz'):
library = parse_liberty(gzip.open(liberty_file, 'rb').read().decode('utf-8'))
else:
library = parse_liberty(open(liberty_file).read())
# Loop through all cells.
for cell_group in library.get_groups('cell'):
cell_name = cell_group.args[0]
libdata[tag][cell_name] = dict()
pin_cap[tag][cell_name] = dict()
def parse_pin_group(pin_group):
if isinstance(pin_group.args[0], str):
pin_name = pin_group.args[0]
else:
pin_name = pin_group.args[0].value
# Access a pin attribute.
rise_capacitance = pin_group['rise_capacitance']
fall_capacitance = pin_group['fall_capacitance']
if rise_capacitance is None:
rise_capacitance = pin_group['max_capacitance']
if fall_capacitance is None:
fall_capacitance = pin_group['max_capacitance']
pin_cap[tag][cell_name][pin_name] = dict()
pin_cap[tag][cell_name][pin_name]['rise_capacitance'] = rise_capacitance
pin_cap[tag][cell_name][pin_name]['fall_capacitance'] = fall_capacitance
timing = pin_group.get_groups('timing')
for tg in timing:
if isinstance(tg['related_pin'], str):
related_pin = tg['related_pin']
else:
related_pin = tg['related_pin'].value
# NOTE: in the sky130hd PDK, pin "Z" of cell sky130_fd_sc_hd__ebufn_1 has two
# timing LUTs related to pin "TE_B", with attribute "timing_type" set to
# "three_state_disable" and "three_state_enable" respectively; the
# rise/fall_transition tables are empty for "three_state_enable" and valid
# for "three_state_disable".
# We simply keep the first LUT encountered for each path: if the path has
# already been seen during parsing, skip it.
cell_path_key = pin_name + "/" + related_pin
if cell_path_key in libdata[tag][cell_name]:
continue
def decode_lut(lut):
def decode_(key):
pat = r'\d+\.\d+|\d+' # match floating-point or integer literals
data_str = ""
for ss in lut[key]:
data_str += ss.value
data_str += " "
val_strs = re.findall(pat, data_str)
vals = []
for val in val_strs:
vals.append(float(val))
return np.array(vals)
result = np.concatenate([
decode_('index_1'), decode_('index_2'), decode_('values')
])
# a 7x7 LUT yields 7 + 7 + 49 = 63 entries
return result
def load_lut(groups):
# decode the first LUT group; fall back to zeros if it is missing or malformed
try:
data = decode_lut(groups[0])
if data.size == 63:
return data, True
except Exception:
pass
return np.zeros(63), False
cell_rise, valid_cell_rise = load_lut(tg.get_groups('cell_rise'))
cell_fall, valid_cell_fall = load_lut(tg.get_groups('cell_fall'))
rise_transition, valid_rise_transition = load_lut(tg.get_groups('rise_transition'))
fall_transition, valid_fall_transition = load_lut(tg.get_groups('fall_transition'))
# per LUT: a validity flag plus the 14 index entries (index_1 + index_2) go in
# the prefix; the 49 table values go in the suffix
prefix = np.zeros(4*15)
suffix = np.zeros(4*49)
lut_valid = [valid_cell_rise, valid_cell_fall, valid_rise_transition, valid_fall_transition]
lut_data = [cell_rise, cell_fall, rise_transition, fall_transition]
for i in range(4):
prefix[i*15] = lut_valid[i]
prefix[i*15+1 : i*15+15] = lut_data[i][:14]
suffix[i*49 : i*49+49] = lut_data[i][14:63]
libdata[tag][cell_name][cell_path_key] = np.concatenate([prefix, suffix])
# Loop through all pins of the cell.
for pin_group in cell_group.get_groups('pin'):
parse_pin_group(pin_group)
for bundle_group in cell_group.get_groups('bundle'):
for pin_group in bundle_group.get_groups('pin'):
parse_pin_group(pin_group)
print(f'finished reading {liberty_file}')
# pickle and save libdata and pin_cap
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "lib_data.pkl"), 'wb') as f:
pickle.dump(libdata, f)
f.close()
with open(os.path.join(save_dir, "pin_caps.pkl"), 'wb') as f:
pickle.dump(pin_cap, f)
f.close()
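# Sketch: reload the pickles written above and inspect one timing-arc vector
# (assumes a prior run populated save_dir; load_libdata comes from utils):
# lib_data, pin_caps = load_libdata('../data/asap7/techlib/parsed_lib')
# cell_arcs = next(iter(lib_data['fast'].values())) # {'pin/related_pin': vector}
# every vector has 4*15 + 4*49 = 256 entries: per LUT a validity flag and the
# 14 index entries in the prefix, then the 49 table values in the suffix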
def load_net_data(data_dir, lib_dir):
# net
net_names = np.load(os.path.join(data_dir, 'net_names.npy'))
net2pin_map = np.load(os.path.join(data_dir, 'net2pin_map.npy'), allow_pickle=True)
with open(os.path.join(data_dir, 'net_name2id_map.pkl'), 'rb') as f:
net_name2id_map = pickle.load(f)
# pin
pin_names_raw = np.load(os.path.join(data_dir, 'pin_names.npy'))
pin_direct = np.load(os.path.join(data_dir, 'pin_direct.npy'))
pin_offset_x = np.load(os.path.join(data_dir, 'pin_offset_x.npy'))
pin_offset_y = np.load(os.path.join(data_dir, 'pin_offset_y.npy'))
pin2net_map = np.load(os.path.join(data_dir, 'pin2net_map.npy'))
pin2node_map = np.load(os.path.join(data_dir, 'pin2node_map.npy'))
# node (instance & port)
node_names = np.load(os.path.join(data_dir, 'node_names.npy'))
node_orient = np.load(os.path.join(data_dir, 'node_orient.npy'))
node_x = np.load(os.path.join(data_dir, 'node_x.npy'))
node_y = np.load(os.path.join(data_dir, 'node_y.npy'))
node_size_x = np.load(os.path.join(data_dir, 'node_size_x.npy'))
node_size_y = np.load(os.path.join(data_dir, 'node_size_y.npy'))
node2pin_map = np.load(os.path.join(data_dir, 'node2pin_map.npy'), allow_pickle=True)
with open(os.path.join(data_dir, 'inst2libcell_map.pkl'), 'rb') as f:
inst2libcell_map = pickle.load(f)
inst_time, pin_time = load_sdf_dict(data_dir)
lib_data, pin_caps = load_libdata(lib_dir)
# features from placement
topos = []
num_nets = net_names.size
for ni in range(num_nets):
net_name = net_names[ni].decode()
fanout, fanin = [], []
net_pins = net2pin_map[ni]
num_net_pins = net_pins.size
# NOTE: consider valid nets only
if num_net_pins <= 1:
continue
# NOTE: do NOT consider nets related to ports
is_port_net = False
for pi in range(num_net_pins):
pin_idx = net_pins[pi]
pin_name = pin_names_raw[pin_idx]
node_idx = pin2node_map[pin_idx]
node_name = node_names[node_idx]
if pin_name == node_name:
is_port_net = True
break
else:
inst_name = node_name.decode()
libcell_pin_name = pin_name.decode()
pin_full_name = inst_name + '/' + libcell_pin_name
pin_x = node_x[node_idx] + pin_offset_x[pin_idx]
pin_y = node_y[node_idx] + pin_offset_y[pin_idx]
pin_dict = {
'pin_full_name': pin_full_name,
'inst': inst_name,
'libcell_pin': libcell_pin_name,
'loc': np.array([pin_x, pin_y]),
}
if pin_direct[pin_idx] == b'OUTPUT':
pin_dict['dir'] = 'out'
fanout.append(pin_dict)
else:
pin_dict['dir'] = 'in'
fanin.append(pin_dict)
if is_port_net:
continue
if len(fanout) != 1:
# NOTE: multi-driver (or driverless) nets would need special handling;
# drop into the debugger if one ever shows up
pdb.set_trace()
topo = {
'net_name': net_name,
'fanin': fanin,
'fanout': fanout,
}
topos.append(topo)
# exclude invalid nets
topos_old = topos
topos = []
for topo in topos_old:
tags = ['fanin', 'fanout']
flag = True
for tag in tags:
for pin_dict in topo[tag]:
if pin_dict['inst'] not in inst2libcell_map:
print(f"Skip {pin_dict['inst']}")
flag = False
break
if not flag:
break
if not flag:
continue
topos.append(topo)
# features from liberty
for topo in topos:
def get_liberty_pin_features(pin_dict):
pin_dict['libcell'] = inst2libcell_map[pin_dict['inst']]
caps = []
caps_dict = pin_caps['fast'][pin_dict['libcell']][pin_dict['libcell_pin']]
caps += [caps_dict['rise_capacitance'], caps_dict['fall_capacitance']]
caps_dict = pin_caps['slow'][pin_dict['libcell']][pin_dict['libcell_pin']]
caps += [caps_dict['rise_capacitance'], caps_dict['fall_capacitance']]
pin_dict['caps'] = caps
pin_dict['num_fanout'] = len(topo['fanin'])
tags = ['fanin', 'fanout']
for tag in tags:
for pin_dict in topo[tag]:
get_liberty_pin_features(pin_dict)
# features from sdf
for topo in topos:
tags = ['fanin', 'fanout']
for tag in tags:
for pin_dict in topo[tag]:
pin_info = pin_time[pin_dict['pin_full_name']]
pin_dict['arrival'] = pin_info.at
pin_dict['required_arrival'] = pin_info.rat
pin_dict['transition'] = pin_info.slew
# delay from the driver (fanout) pin to each sink (fanin) pin
assert(len(topo['fanout']) == 1)
for pin_dict in topo['fanin']:
delay = pin_dict['arrival'] - topo['fanout'][0]['arrival']
pin_dict['delay_from_fanout'] = delay
pin_dict['transition_incr'] = pin_dict['transition'] - topo['fanout'][0]['transition']
# Drive Strength
# according to the ASAP7 paper:
# https://github.com/The-OpenROAD-Project/asap7_pdk_r1p7/blob/1ff7649bbf423207f6f70293dc1cf630cd477365/docs/mej_paper_asap7.pdf
# encoded from the Vt-flavor suffix, high -> low drive:
# SLVT (4), LVT (3), RVT (2), SRAM (1)
for topo in topos:
tags = ['fanin', 'fanout']
for tag in tags:
for pin_dict in topo[tag]:
cell_name = pin_dict['libcell']
# find driving strength
if cell_name.endswith('_SL'): # SLVT
drive_strength = 4
elif cell_name.endswith('_L'): # LVT
drive_strength = 3
elif cell_name.endswith('_R'): # RVT
drive_strength = 2
elif cell_name.endswith('_SRAM'):
drive_strength = 1
else:
drive_strength = 0
pin_dict['libcell_drive_strength'] = drive_strength
return topos
def to_json_format(topos):
# np.ndarray values are not JSON-serializable; convert them to lists
keys = ['loc', 'transition', 'arrival', 'required_arrival', 'delay_from_fanout', 'transition_incr']
tags = ['fanin', 'fanout']
for topo in topos:
for tag in tags:
for pin_dict in topo[tag]:
for key in keys:
if key in pin_dict:
pin_dict[key] = pin_dict[key].tolist()
if __name__ == '__main__':
pdk = "asap7"
tag = "no_timing_opt"
raw_data_dir = f"../data/{pdk}"
techlib_dir = os.path.join(raw_data_dir, "techlib")
lib_dir = os.path.join(techlib_dir, 'parsed_lib')
do_parse_libs = False # flip to True to re-parse the liberty files
if do_parse_libs:
all_libs = os.listdir(os.path.join(techlib_dir, "lib"))
fast_libs = []
slow_libs = []
for lib in all_libs:
if '.swp' in lib: # skip editor swap files
continue
if '_FF_' in lib:
fast_libs.append(os.path.join(techlib_dir, "lib", lib))
elif '_SS_' in lib:
slow_libs.append(os.path.join(techlib_dir, "lib", lib))
elif 'fake' in lib or 'FAKE' in lib:
fast_libs.append(os.path.join(techlib_dir, "lib", lib))
slow_libs.append(os.path.join(techlib_dir, "lib", lib))
libs = {}
libs['fast'] = fast_libs
libs['slow'] = slow_libs
save_dir = os.path.join(techlib_dir, "parsed_lib")
parse_libs2(libs, save_dir)
blocks = "aes aes-mbff gcd ibex jpeg uart".split()
for block in blocks:
print(f'-------- {block} --------')
block_dir = os.path.join(raw_data_dir, tag, block)
data_dir = os.path.join(block_dir, 'parsed')
topos = load_net_data(data_dir, lib_dir)
with open(os.path.join(block_dir, 'parsed', f'{block}.topo.dict'), 'wb') as f:
pickle.dump(topos, f)
to_json_format(topos)
with open(os.path.join(block_dir, 'parsed', f'{block}.topo.json'), 'w') as f:
json.dump(topos, f, indent=4)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
from torch.utils.tensorboard import SummaryWriter
from models import NetPred
from gather_data import build_loaders
import pdb
def train(train_loaders, test_loaders):
input_dim, output_dim, n_heads, hidden_size_per_head = 9, 4, 8, 64
model = NetPred(input_dim, output_dim, n_heads, hidden_size_per_head)
model.cuda()
writer = SummaryWriter()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
# number of fanout-degree buckets (one DataLoader each); used below to
# normalize the summed losses
num_fanin_pins_train = len(train_loaders)
num_fanin_pins_test = len(test_loaders)
best_train_loss, best_test_loss = 1e9, 1e9
os.makedirs('weights', exist_ok=True)
start_time = time.time()
for e in range(100000):
epoch_start_time = time.time()
tot_delay_loss, tot_transition_loss = 0, 0
model.train()
optimizer.zero_grad()
for key, loader in train_loaders.items():
for batch_idx, (pin_features, transition, transition_incr, delay) in enumerate(loader):
fanout_transition = transition[:, -1].cuda()
pin_features = pin_features.cuda()
transition_incr = transition_incr[:, :-1, :].cuda()
delay = delay[:, :-1, :].cuda()
# fanout_transition can be extremely large for high-fanout nets;
# such nets are already excluded via max_fanout in build_loaders
fanout_transition.log_()
delay_pred, transition_incr_pred = model(pin_features, fanout_transition)
# delay values are small (1e-2 ~ 1e0): shift, scale, then log-transform
delay.add_(1e-6).mul_(1e2).log_()
delay_loss = F.mse_loss(delay, delay_pred)
# transition_incr values are small (1e-3 ~ 1e0): same shift-scale-log transform
transition_incr.add_(1e-6).mul_(1e3).log_()
transition_loss = F.mse_loss(transition_incr, transition_incr_pred)
# tot_delay_loss += delay_loss.item() / num_fanin_pins_train
# tot_transition_loss += transition_loss.item() / num_fanin_pins_train
tot_delay_loss += delay_loss.item()
tot_transition_loss += transition_loss.item()
optimizer.zero_grad() # clear gradients before each per-batch update
(delay_loss + transition_loss).backward()
optimizer.step()
tot_delay_loss /= num_fanin_pins_train
tot_transition_loss /= num_fanin_pins_train
writer.add_scalar('delay_loss/train', tot_delay_loss, e)
writer.add_scalar('transition_loss/train', tot_transition_loss, e)
if e >= 20:
tot_loss = tot_delay_loss + tot_transition_loss
if tot_loss < best_train_loss:
best_train_loss = tot_loss
save_path = os.path.join('weights', 'e-%d-train_delay_loss-%.3f-train_transition_loss-%.3f.pt' %(e + 1, tot_delay_loss, tot_transition_loss))
torch.save(model.state_dict(), save_path)
print(f'NOTE: save model to {save_path}')
if (e + 1) % 20 == 0:
# evaluation
model.eval()
with torch.no_grad():
tot_delay_loss_test, tot_transition_loss_test = 0, 0
for key, loader in test_loaders.items():
for batch_idx, (pin_features, transition, transition_incr, delay) in enumerate(loader):
fanout_transition = transition[:, -1].cuda()
pin_features = pin_features.cuda()
transition_incr = transition_incr[:, :-1, :].cuda()
delay = delay[:, :-1, :].cuda()
# fanout_transition can be extremely large for high-fanout nets;
# such nets are already excluded via max_fanout in build_loaders
fanout_transition.log_()
delay_pred, transition_incr_pred = model(pin_features, fanout_transition)
# delay values are small (1e-2 ~ 1e0): shift, scale, then log-transform
delay.add_(1e-6).mul_(1e2).log_()
delay_loss = F.mse_loss(delay, delay_pred)
# transition_incr values are small (1e-3 ~ 1e0): same shift-scale-log transform
transition_incr.add_(1e-6).mul_(1e3).log_()
transition_loss = F.mse_loss(transition_incr, transition_incr_pred)
# tot_delay_loss_test += delay_loss.item() / num_fanin_pins_test
# tot_transition_loss_test += transition_loss.item() / num_fanin_pins_test
tot_delay_loss_test += delay_loss.item()
tot_transition_loss_test += transition_loss.item()
tot_delay_loss_test /= num_fanin_pins_test
tot_transition_loss_test /= num_fanin_pins_test
writer.add_scalar('delay_loss/test', tot_delay_loss_test, e)
writer.add_scalar('transition_loss/test', tot_transition_loss_test, e)
print("epoch: {}\n (Train)\t tot_loss: {:.6f}\t delay_loss: {:.6f}\t transition_loss: {:.6f}\n (Test) \t tot_loss: {:.6f}\t delay_loss: {:.6f}\t transition_loss: {:.6f}".format(
e + 1,
tot_delay_loss + tot_transition_loss, tot_delay_loss, tot_transition_loss,
tot_delay_loss_test + tot_transition_loss_test, tot_delay_loss_test, tot_transition_loss_test
))
tot_loss = tot_delay_loss_test + tot_transition_loss_test
if tot_loss < best_test_loss:
best_test_loss = tot_loss
save_path = os.path.join('weights', 'e-%d-test_delay_loss-%.3f-test_transition_loss-%.3f.pt' %(e + 1, tot_delay_loss_test, tot_transition_loss_test))
torch.save(model.state_dict(), save_path)
print(f'NOTE: save model to {save_path}')
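# At inference time the target transforms above must be inverted to get back
# to library units (a sketch mirroring the constants used in train()):
# delay = delay_pred.exp() / 1e2 - 1e-6
# transition_incr = transition_incr_pred.exp() / 1e3 - 1e-6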
if __name__ == '__main__':
train_loaders, test_loaders = build_loaders()
train(train_loaders, test_loaders)
import numpy as np
import torch
import dgl
import os
import pickle
import re
import pdb
import gzip
import queue
from collections import defaultdict
from liberty.parser import parse_liberty
from liberty.types import *
class Delay:
def __init__(self, src, dst, val):
self.src = src
self.dst = dst
self.value = val
class CellDelay:
def __init__(self):
self.celltype = None
self.inst = None
self.interconnect = []
self.iopath = []
self.isFF = False # deprecate this
class ArrivalTime:
def __init__(self):
self.at = None
self.slew = None
self.rat = None
def parse_sdf(sdf_file, block_name, save_dir):
with open(sdf_file, 'r') as f:
lines = f.readlines()
# TODO: deal with FF
# e.g. inst '_73818_' of celltype 'sky130_fd_sc_hd__dfxtp_1' in block 'aes128'
inst_time = dict()
# get all insts
for line in lines:
words = line.split()
if words[0] == "(INSTANCE)":
inst_time[block_name] = CellDelay()
elif words[0] == "(INSTANCE":
inst_time[words[1][:-1]] = CellDelay()
print(f"num insts: {len(inst_time)}")
pin_time = dict()
invalid_cnt = 0
tot_cnt = 0
num_lines = len(lines)
# NOTE: li is advanced inside the body, but `for li in range(...)` resets it
# each iteration; already-consumed lines get re-scanned, which is harmless
# (they never match the top-level keywords) but wasteful
for li in range(num_lines):
line = lines[li]
# parse CELL field
words = line.split()
if words[0] == "(CELL":
cell_type = lines[li + 1].split()[-1][1:-2]
inst = lines[li + 2].split()[-1][:-1]
if len(lines[li + 2].split()) == 1:
inst = block_name
inst_time[inst].celltype = cell_type
inst_time[inst].inst = inst
li += 5
while True:
line = lines[li]
li += 1
words = line.split()
if words[0] == ')':
break
elif words[0] == "(INTERCONNECT":
src = words[1]
dst = words[2]
if src == "VGND" or src == 'VPWR' or dst == "VGND" or dst == "VPWR":
continue
else:
numbers = words[-2] + words[-1]
# pat = r'\d+\.\d+|\d+' # find all floating numbers
pat = r'\d+\.\d+' # find all floating numbers
vals_str = re.findall(pat, numbers)
vals = []
for vs in vals_str:
vals.append(float(vs))
tot_cnt += 1
if len(vals) != 4:
# unclear why some lines carry only 2 values; OpenSTA shows that the
# min and max rise/fall values are identical in those cases, so
# duplicate them to fill all four slots
if len(vals) != 2:
invalid_cnt += 1
print(line)
continue
vals.append(vals[0])
vals.append(vals[1])
# reorder the four values to match the DAC'22 paper's convention
vals_np = np.zeros(4, dtype=np.double)
vals_np[0] = vals[0]
vals_np[1] = vals[2]
vals_np[2] = vals[1]
vals_np[3] = vals[3]
inst_time[inst].interconnect.append(
Delay(src, dst, vals_np)
)
elif words[0] == "(IOPATH":
src = words[1]
dst = words[2]
if src == "VGND" or src == 'VPWR' or dst == "VGND" or dst == "VPWR":
continue
else:
numbers = words[-2] + words[-1]
# pat = r'\d+\.\d+|\d+' # find all floating numbers
pat = r'\d+\.\d+' # find all floating numbers
vals_str = re.findall(pat, numbers)
vals = []
for vs in vals_str:
vals.append(float(vs))
tot_cnt += 1
if len(vals) != 4:
if len(vals) != 2:
invalid_cnt += 1
print(line)
continue
vals.append(vals[0])
vals.append(vals[1])
# reorder the four values to match the DAC'22 paper's convention
vals_np = np.zeros(4, dtype=np.double)
vals_np[0] = vals[0]
vals_np[1] = vals[2]
vals_np[2] = vals[1]
vals_np[3] = vals[3]
inst_time[inst].iopath.append(
Delay(src, dst, vals_np)
)
# parse ARRIVALTIMES field
elif words[0] == "(ARRIVALTIMES":
li += 1
while True:
line = lines[li]
li += 1
words = line.split()
if words[0] == ')':
break
elif words[0] == "(AT":
pin_name = words[1]
if pin_name == "VGND" or pin_name == "VPER":
continue
numbers = words[-2] + words[-1]
# pat = r'\d+\.\d+|\d+' # find all floating numbers
pat = r'\d+\.\d+' # find all floating numbers
vals_str = re.findall(pat, numbers)
vals = []
for vs in vals_str:
vals.append(float(vs))
assert(len(vals) == 4)
# reorder the four values to match the DAC'22 paper's convention
vals_np = np.zeros(4, dtype=np.double)
vals_np[0] = vals[0]
vals_np[1] = vals[2]
vals_np[2] = vals[1]
vals_np[3] = vals[3]
if pin_name not in pin_time:
pin_time[pin_name] = ArrivalTime()
pin_time[pin_name].at = vals_np
elif words[0] == "(SLEW":
pin_name = words[1]
if pin_name == "VGND" or pin_name == "VPER":
continue
numbers = words[-2] + words[-1]
pat = r'\d+\.\d+|\d+' # find all floating numbers
vals_str = re.findall(pat, numbers)
vals = []
for vs in vals_str:
vals.append(float(vs))
assert(len(vals) == 4)
# reorder the four values to match the DAC'22 paper's convention
vals_np = np.zeros(4, dtype=np.double)
vals_np[0] = vals[0]
vals_np[1] = vals[2]
vals_np[2] = vals[1]
vals_np[3] = vals[3]
if pin_name not in pin_time:
pin_time[pin_name] = ArrivalTime()
pin_time[pin_name].slew = vals_np
elif words[0] == "(RAT":
pin_name = words[1]
if pin_name == "VGND" or pin_name == "VPER":
continue
numbers = words[-2] + words[-1]
pat = r'\d+\.\d+|\d+' # find all floating numbers
vals_str = re.findall(pat, numbers)
vals = []
for vs in vals_str:
vals.append(float(vs))
assert(len(vals) == 4)
# reorder the four values to match the DAC'22 paper's convention
vals_np = np.zeros(4, dtype=np.double)
vals_np[0] = vals[0]
vals_np[1] = vals[2]
vals_np[2] = vals[1]
vals_np[3] = vals[3]
if pin_name not in pin_time:
pin_time[pin_name] = ArrivalTime()
pin_time[pin_name].rat = vals_np
print(f'invalid_cnt / tot_cnt: {invalid_cnt} / {tot_cnt}, {float(invalid_cnt / tot_cnt) * 100}%')
# second pass: mark flip-flop instances; a TIMINGCHECK section inside a
# CELL record implies the instance is sequential
for li in range(num_lines):
line = lines[li]
# parse CELL field
words = line.split()
if words[0] == "(CELL":
cell_type = lines[li + 1].split()[-1][1:-2]
inst = lines[li + 2].split()[-1][:-1]
if len(lines[li + 2].split()) == 1:
inst = block_name
elif words[0] == "(TIMINGCHECK":
inst_time[inst].isFF = True
with open(os.path.join(save_dir, "inst_time.dict.pkl"), 'wb') as f:
pickle.dump(inst_time, f)
f.close()
with open(os.path.join(save_dir, "pin_time.dict.pkl"), 'wb') as f:
pickle.dump(pin_time, f)
f.close()
def load_sdf_dict(save_dir):
with open(os.path.join(save_dir, "inst_time.dict.pkl"), 'rb') as f:
inst_time = pickle.load(f)
with open(os.path.join(save_dir, "pin_time.dict.pkl"), 'rb') as f:
pin_time = pickle.load(f)
return inst_time, pin_time
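# Usage sketch (hypothetical pin name): query the parsed SDF timing for a pin.
# inst_time, pin_time = load_sdf_dict('parsed')
# info = pin_time['u1/A'] # an ArrivalTime with .at / .slew / .rat fields,
# each a length-4 array ordered per the DAC'22 convention above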
def parse_libs(libs, save_dir):
libdata = dict()
pin_cap = dict()
for tag in libs:
liberty_files = libs[tag]
libdata[tag] = dict()
pin_cap[tag] = dict()
for liberty_file in liberty_files:
if liberty_file.endswith('.gz'):
library = parse_liberty(gzip.open(liberty_file, 'rb').read().decode('utf-8'))
else:
library = parse_liberty(open(liberty_file).read())
# Loop through all cells.
for cell_group in library.get_groups('cell'):
cell_name = cell_group.args[0]
libdata[tag][cell_name] = dict()
pin_cap[tag][cell_name] = dict()
def parse_pin_group(pin_group):
if isinstance(pin_group.args[0], str):
pin_name = pin_group.args[0]
else:
pin_name = pin_group.args[0].value
# Access a pin attribute.
rise_capacitance = pin_group['rise_capacitance']
fall_capacitance = pin_group['fall_capacitance']
if rise_capacitance is None:
rise_capacitance = 0.0
if fall_capacitance is None:
fall_capacitance = 0.0
pin_cap[tag][cell_name][pin_name] = dict()
pin_cap[tag][cell_name][pin_name]['rise_capacitance'] = rise_capacitance
pin_cap[tag][cell_name][pin_name]['fall_capacitance'] = fall_capacitance
timing = pin_group.get_groups('timing')
for tg in timing:
if isinstance(tg['related_pin'], str):
related_pin = tg['related_pin']
else:
related_pin = tg['related_pin'].value
# NOTE: in the sky130hd PDK, pin "Z" of cell sky130_fd_sc_hd__ebufn_1 has two
# timing LUTs related to pin "TE_B", with attribute "timing_type" set to
# "three_state_disable" and "three_state_enable" respectively; the
# rise/fall_transition tables are empty for "three_state_enable" and valid
# for "three_state_disable".
# We simply keep the first LUT encountered for each path: if the path has
# already been seen during parsing, skip it.
cell_path_key = pin_name + "/" + related_pin
if cell_path_key in libdata[tag][cell_name]:
continue
def decode_lut(lut):
def decode_(key):
pat = r'\d+\.\d+|\d+' # match floating-point or integer literals
data_str = ""
for ss in lut[key]:
data_str += ss.value
data_str += " "
val_strs = re.findall(pat, data_str)
vals = []
for val in val_strs:
vals.append(float(val))
return np.array(vals)
result = np.concatenate([
decode_('index_1'), decode_('index_2'), decode_('values')
])
return result
def load_lut(groups):
# decode the first LUT group; fall back to zeros if it is missing or malformed
try:
data = decode_lut(groups[0])
if data.size == 63:
return data, True
except Exception:
pass
return np.zeros(63), False
cell_rise, valid_cell_rise = load_lut(tg.get_groups('cell_rise'))
cell_fall, valid_cell_fall = load_lut(tg.get_groups('cell_fall'))
rise_transition, valid_rise_transition = load_lut(tg.get_groups('rise_transition'))
fall_transition, valid_fall_transition = load_lut(tg.get_groups('fall_transition'))
prefix = np.zeros(4*15)
suffix = np.zeros(4*49)
lut_valid = [valid_cell_rise, valid_cell_fall, valid_rise_transition, valid_fall_transition]
lut_data = [cell_rise, cell_fall, rise_transition, fall_transition]
for i in range(4):
prefix[i*15] = lut_valid[i]
prefix[i*15+1 : i*15+15] = lut_data[i][:14]
suffix[i*49 : i*49+49] = lut_data[i][14:63]
libdata[tag][cell_name][cell_path_key] = np.concatenate([prefix, suffix])
# Loop through all pins of the cell.
for pin_group in cell_group.get_groups('pin'):
parse_pin_group(pin_group)
for bundle_group in cell_group.get_groups('bundle'):
for pin_group in bundle_group.get_groups('pin'):
parse_pin_group(pin_group)
print(f'finished reading {liberty_file}')
# pickle and save libdata and pin_cap
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "lib_data.pkl"), 'wb') as f:
pickle.dump(libdata, f)
f.close()
with open(os.path.join(save_dir, "pin_caps.pkl"), 'wb') as f:
pickle.dump(pin_cap, f)
f.close()
def load_libdata(save_dir):
with open(os.path.join(save_dir, "lib_data.pkl"), 'rb') as f:
libdata = pickle.load(f)
f.close()
with open(os.path.join(save_dir, "pin_caps.pkl"), 'rb') as f:
pincaps = pickle.load(f)
f.close()
return libdata, pincaps
def parse_inst2libcell_map(def_file, save_dir):
with open(def_file, 'r') as f:
lines = f.readlines()
inst2libcell_map = dict()
for li in range(len(lines)):
line = lines[li]
words = line.split()
if len(words) == 0:
continue
if words[0] == 'COMPONENTS':
li += 1
while True:
line = lines[li]
li += 1
words = line.split()
if len(words) >= 2 and words[0] == 'END' and words[1] == 'COMPONENTS':
break
# component entries look like "- <inst> <libcell> ...": skip blank
# lines and "+ ..." continuation lines
if len(words) < 3 or words[0] != '-':
continue
inst = words[1]
libcell = words[2]
inst2libcell_map[inst] = libcell
break
with open(os.path.join(save_dir, 'inst2libcell_map.pkl'), 'wb') as f:
pickle.dump(inst2libcell_map, f)
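# A well-formed DEF component entry looks like (hypothetical instance):
# - u1 BUFx2_ASAP7_75t_R + PLACED ( 10000 20000 ) N ;
# so words[1] is the instance name and words[2] the libcell name.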
timing_arc_types = [
'clk2q', 'in2out', # cell arc
'inport2d', 'inport2clk', 'inport2in', # src = input port
'q2outport', 'out2outport', # dst = output port
'inport2outport', # feed-through
'q2in', 'out2in', 'out2d', 'q2d', # wire arc
]
# map each timing-arc type to an integer id
timing_arc_types2id_map = {t: i for i, t in enumerate(timing_arc_types)}
class TimingArc:
# src, dst: pin full name
def __init__(self, src, dst):
self.src = src
self.dst = dst
self.type = None
self.buffers = [] # ordered buffer chain along src->dst, as (buf_in, buf_out) pairs
self.delay = None # delays from src->dst
class TimingGraph:
def __init__(self, inst2libcell_map):
self.inst2libcell_map = inst2libcell_map
self.all_arcs = []
self.arc_groups = dict()
self.init()
def init(self):
for arc_type in timing_arc_types:
self.arc_groups[arc_type] = []
def _get_arc_type(self, arc):
_src = arc.src.split('/')
_dst = arc.dst.split('/')
if len(_src) == 1 and len(_dst) == 1: # feed through
return 'inport2outport'
elif len(_src) == 1 and len(_dst) > 1: # from input port
dst_inst = _dst[0]
dst_libcell = self.inst2libcell_map[dst_inst]
if 'FF' in dst_libcell or 'LL' in dst_libcell: # LL: Latch
dst_pin_raw = _dst[1]
if 'D' in dst_pin_raw:
return 'inport2d'
else:
return 'inport2clk'
else:
return 'inport2in'
elif len(_src) > 1 and len(_dst) == 1: # to output port
src_inst = _src[0]
if src_inst == 'ethmac_1634.DREAMPlace.Shape0': # debug
pdb.set_trace()
src_libcell = self.inst2libcell_map[src_inst]
if 'FF' in src_libcell or 'LL' in src_libcell:
return 'q2outport'
else:
return 'out2outport'
else:
src_inst = _src[0]
dst_inst = _dst[0]
src_libcell = self.inst2libcell_map[src_inst]
dst_libcell = self.inst2libcell_map[dst_inst]
if src_inst == dst_inst:
if 'FF' in src_libcell or 'LL' in src_libcell: # cell arc (same inst, same libcell)
return 'clk2q'
else:
return 'in2out'
else: # wire arc
if ('FF' in src_libcell and 'FF' in dst_libcell) or ('LL' in src_libcell and 'LL' in dst_libcell):
return 'q2d'
elif 'FF' in src_libcell or 'LL' in src_libcell:
return 'q2in'
elif 'FF' in dst_libcell or 'LL' in dst_libcell:
return 'out2d'
else:
return 'out2in'
def add_arc(self, arc):
self.all_arcs.append(arc)
# add arc into corresponding groups
arc_type = self._get_arc_type(arc)
arc.type = arc_type
self.arc_groups[arc_type].append(arc)
def get_arc_group(self, arc_type):
return self.arc_groups[arc_type]
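# Usage sketch (hypothetical insts/libcells): classify a single wire arc.
# g = TimingGraph({'u1': 'BUFx2_ASAP7_75t_R', 'ff1': 'DFFHQNx1_ASAP7_75t_R'})
# g.add_arc(TimingArc('u1/Y', 'ff1/D'))
# g.all_arcs[0].type # 'out2d': combinational driver into an FF's D pin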
# walk forward from `pin` through the arc graph until `target` is reached;
# returns the arc chain in reverse (dst-to-src) order plus a success flag,
# and stops at placed pins (except the starting pin itself)
def dfs(pin, target, place_pins, src_pin2path_map, first_call=False):
if pin == target:
return [], True
if not first_call and pin in place_pins:
# stop searching
return [], False
if pin not in src_pin2path_map:
return [], False
flag = False
pin_paths = src_pin2path_map[pin]
for pin_path in pin_paths:
chain, flag = dfs(pin_path.dst, target, place_pins, src_pin2path_map)
if flag:
chain.append(pin_path)
return chain, True
# if not flag
return [], False
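# Sketch of stitching a buffered connection back together with dfs()
# (hypothetical pins); src_pin2path_map maps each pin to the arcs starting there:
# arcs = [TimingArc('u1/Y', 'buf1/A'), TimingArc('buf1/A', 'buf1/Y'),
#         TimingArc('buf1/Y', 'u2/A')]
# src_pin2path_map = defaultdict(list)
# for a in arcs:
#     src_pin2path_map[a.src].append(a)
# chain, found = dfs('u1/Y', 'u2/A', {'u1/Y', 'u2/A'}, src_pin2path_map, first_call=True)
# found is True; reversed(chain) walks the arcs from 'u1/Y' to 'u2/A'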