Commit f4db720e by lvzhengyang

add dataset builder

parent da9db63a
......@@ -28,6 +28,25 @@ yosys = "yosys"
The generated CDFG is copied to `results/<case>/cdfg` by default (or the `output_dir` you set).
### Build Dataset (CDFG + Waveforms)
Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and run:
```bash
./main.py build_dataset --case eth_fifo
```
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
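To spot-check the saved dataset from Python, a minimal sketch (assuming numpy is installed and the default output paths above):
```python
import json

import numpy as np

dataset_dir = "results/eth_fifo/dataset"

# meta.json records the source CDFG and the parsed VCD traces.
with open(f"{dataset_dir}/meta.json") as f:
    meta = json.load(f)
print("traces:", len(meta.get("vcd_files", [])))

# graphs.npz stores a pickled dict of designs under the "designs" key.
with np.load(f"{dataset_dir}/graphs.npz", allow_pickle=True) as data:
    designs = data["designs"].item()
for name, design in designs.items():
    print(name, design["x"].shape, design["edge_index"].shape)
```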
### View CDFG / Dataset
```bash
./main.py view_cdfg --case eth_fifo
./main.py view_dataset --case eth_fifo
```
The commands print a summary to the terminal and save figures under `results/<case>/`.
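If Graphviz is installed you can also re-render the figure by hand; a sketch, assuming `dot` is on your PATH and substituting your generated `.dot` file:
```bash
dot -Tpng results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot -o results/eth_fifo/cdfg/view_cdfg.png
```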
---
## Legacy RTL2CDFG example
......
......@@ -10,3 +10,22 @@ divide = false
yosys = "yosys"
# Optional override for output location. Defaults to results/<case>/cdfg.
# output_dir = "results/eth_fifo/cdfg"
[waveform]
# Glob is relative to cases/<case>/.
vcd_glob = "sim/*/*/waveform.vcd"
# Signals are relative to the scope below.
scope = "tb_eth_fifo.u_dut"
clock = "clk"
reset = "reset"
reset_active = "high"
# Optional: posedge or negedge.
clock_edge = "posedge"
# Optional: limit number of sampled cycles per trace.
# max_cycles = 200
[dataset]
# Optional override. Defaults to results/<case>/dataset.
# output_dir = "results/eth_fifo/dataset"
# Optional override for CDFG dot if needed.
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import List
import os
import os.path as osp
import random

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

from utils.data_utils import read_npz_file
from parser_func import parse_pyg_mlpgate
class NpzParser():
    '''
    Parse the npz file into an in-memory torch_geometric.data.Data dataset.
    '''

    def __init__(self, data_dir, circuit_path, label_path,
                 random_shuffle=False, trainval_split=0.9):
        self.data_dir = data_dir
        self.dataset = self.inmemory_dataset(data_dir, circuit_path, label_path)
        self.trainval_split = trainval_split
        self.random_shuffle = random_shuffle
        # Loader construction is left to the caller, e.g.:
        # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
        # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    def get_dataset(self, split_with_design=False):
        data_len = len(self.dataset)
        if not split_with_design:
            # Split over individual graphs, optionally shuffled first.
            if self.random_shuffle:
                perm = torch.randperm(len(self.dataset))
                self.dataset = self.dataset[perm]
            training_cutoff = int(data_len * self.trainval_split)
            train_dataset = self.dataset[:training_cutoff]
            val_dataset = self.dataset[training_cutoff:]
        else:
            # Split by design so that all traces of one module land in the
            # same split (trace graphs are named '<module>_trace<i>').
            module_name_set = set()
            for i in range(len(self.dataset)):
                module_name_set.add(self.dataset[i].name.split('_trace')[0])
            module_name_list = sorted(module_name_set)
            random.seed(666)  # the same shuffle each time
            random.shuffle(module_name_list)
            select_modules = module_name_list[:int(len(module_name_set) * self.trainval_split)]
            if self.random_shuffle:
                perm = torch.randperm(len(self.dataset))
                self.dataset = self.dataset[perm]
            sel_mask = []
            for i in range(len(self.dataset)):
                module_name = self.dataset[i].name.split('_trace')[0]
                sel_mask.append(module_name in select_modules)
            train_dataset = self.dataset[np.array(sel_mask)]
            val_dataset = self.dataset[~np.array(sel_mask)]
        return train_dataset, val_dataset
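    # Example usage (a sketch; paths are hypothetical and must point at the
    # graph/label npz files produced upstream):
    #   parser = NpzParser('data', 'data/graphs.npz', 'data/labels.npz')
    #   train_set, val_set = parser.get_dataset(split_with_design=True)
    #   train_loader = DataLoader(train_set, batch_size=32, shuffle=True)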
    class inmemory_dataset(InMemoryDataset):
        def __init__(self, root, graph_path, label_path, transform=None, pre_transform=None, pre_filter=None):
            self.name = 'npz_inmm_dataset'
            self.root = root
            self.graph_path = graph_path
            self.label_path = label_path
            super().__init__(root, transform, pre_transform, pre_filter)
            self.data, self.slices = torch.load(self.processed_paths[0])

        @property
        def raw_dir(self):
            return self.root

        @property
        def processed_dir(self):
            name = 'inmemory-' + os.path.basename(self.graph_path).split('.')[0]
            return osp.join(self.root, name)

        @property
        def raw_file_names(self) -> List[str]:
            return [self.graph_path, self.label_path]

        @property
        def processed_file_names(self) -> List[str]:
            return ['data.pt']

        def download(self):
            pass

        def process(self):
            data_list = []
            designs = read_npz_file(self.graph_path)['designs'].item()
            labels = read_npz_file(self.label_path)['labels'].item()
            for design_idx, design_name in enumerate(designs):
                print('Parse design: {}, {:} / {:} = {:.2f}%'.format(
                    design_name, design_idx, len(designs), design_idx / len(designs) * 100))
                # x[i][0]: node_id; x[i][1]: node_type; x[i][2]: node_width
                x = designs[design_name]['x']
                edge_index = designs[design_name]['edge_index']
                edge_type = designs[design_name]['edge_type']
                sim_res = designs[design_name]['sim_res']
                has_sim_res = designs[design_name]['has_sim_res']
                power = designs[design_name]['power']
                slack = designs[design_name]['slack']
                area = designs[design_name]['area']
                # Skip designs with missing or degenerate PPA labels.
                if power is None or slack is None or area is None \
                        or abs(slack) < 1e-5 or abs(area) < 1e-5 or abs(power[0]) < 1e-5:
                    continue
                y = labels[design_name]['y']
                # added by XXXX-5 2024/09/15
                # using each simulation result as a separate data point
                for trace_id, trace in enumerate(sim_res):
                    graph = parse_pyg_mlpgate(
                        x=x, edge_index=edge_index, edge_type=edge_type,
                        sim_res=trace, has_sim_res=has_sim_res, y=y,
                        power=power[trace_id], slack=slack, area=area
                    )
                    graph.name = design_name + '_trace' + str(trace_id)
                    data_list.append(graph)
            data, slices = self.collate(data_list)
            torch.save((data, slices), self.processed_paths[0])
            print('[INFO] Inmemory dataset save: ', self.processed_paths[0])
            print('Total Designs: {:}'.format(len(data_list)))

        def __repr__(self) -> str:
            return f'{self.name}({len(self)})'
\ No newline at end of file
#!/usr/bin/env python3
import argparse
import json
import re
import shutil
import subprocess
import sys
......@@ -22,6 +24,69 @@ def _load_case_config(case_name: str) -> dict:
        return tomllib.load(f)
def _import_dataset_builder():
    try:
        import dataset_builder
    except Exception as exc:
        raise RuntimeError(
            "Failed to import dataset_builder. Ensure numpy is installed for dataset operations."
        ) from exc
    return dataset_builder
def _parse_dot_graph(dot_path: Path):
    node_attrs = {}
    node_order = []
    edges = []

    def add_node(name: str, attrs=None) -> None:
        if name not in node_attrs:
            node_order.append(name)
            node_attrs[name] = attrs or {}
        elif attrs:
            node_attrs[name].update(attrs)

    def parse_attr_map(attr_text: str):
        attrs = {}
        for part in attr_text.split(","):
            part = part.strip()
            if not part or "=" not in part:
                continue
            key, val = part.split("=", 1)
            attrs[key.strip()] = val.strip().strip('"')
        return attrs

    with dot_path.open("r") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("digraph") or line in {"{", "}"}:
                continue
            if "->" in line:
                m = re.search(r'"?(.+?)"?\s*->\s*"?(.+?)"?\s*\[label="(.*?)"\]', line)
                if not m:
                    continue
                src, dst, label = m.group(1), m.group(2), m.group(3)
                src = src.strip().strip('"')
                dst = dst.strip().strip('"')
                add_node(src)
                add_node(dst)
                edges.append((src, dst, label))
                continue
            if "[" in line and "]" in line:
                name_part, attr_part = line.split("[", 1)
                name = name_part.strip().strip(";").strip().strip('"')
                attr_text = attr_part.split("]", 1)[0]
                attrs = parse_attr_map(attr_text)
                add_node(name, attrs)
            else:
                name = line.strip().strip(";").strip().strip('"')
                if name:
                    add_node(name)
    return node_order, node_attrs, edges
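# The parser above targets DOT files of roughly this shape (an assumed example,
# not taken from an actual generated CDFG):
#   digraph cdfg {
#     "n1" [label="add", color="red"];
#     "n2";
#     "n1" -> "n2" [label="data"];
#   }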
def _resolve_rtl_files(case_dir: Path, rtl_files: list[str]) -> list[str]:
    resolved = []
    missing = []
......@@ -40,6 +105,46 @@ def _resolve_rtl_files(case_dir: Path, rtl_files: list[str]) -> list[str]:
    return resolved
def _resolve_cdfg_dir(case_name: str, config: dict) -> Path:
    cdfg_cfg = config.get("cdfg", {})
    output_dir_cfg = cdfg_cfg.get("output_dir")
    if output_dir_cfg:
        output_dir = Path(output_dir_cfg)
        if not output_dir.is_absolute():
            output_dir = (REPO_ROOT / output_dir).resolve()
    else:
        output_dir = (REPO_ROOT / "results" / case_name / "cdfg").resolve()
    return output_dir
def _select_cdfg_dot(case_name: str, config: dict, cdfg_dir: Path) -> Path:
    dataset_cfg = config.get("dataset", {})
    cdfg_cfg = config.get("cdfg", {})
    dot_override = dataset_cfg.get("cdfg_dot") or cdfg_cfg.get("dot")
    if dot_override:
        dot_path = Path(dot_override)
        if not dot_path.is_absolute():
            dot_path = (REPO_ROOT / dot_path).resolve()
        return dot_path
    dot_files = sorted(cdfg_dir.glob("*.dot"))
    if not dot_files:
        raise FileNotFoundError(f"No .dot files found in {cdfg_dir}")
    return dot_files[0]
def _render_dot(dot_path: Path, output_png: Path) -> bool:
    dot_bin = shutil.which("dot")
    if not dot_bin:
        print("[view] Graphviz 'dot' not found. Skipping PNG render.")
        return False
    output_png.parent.mkdir(parents=True, exist_ok=True)
    result = subprocess.run([dot_bin, "-Tpng", str(dot_path), "-o", str(output_png)])
    if result.returncode != 0:
        print("[view] dot failed to render PNG.")
        return False
    return True
def _build_cdfg(case_name: str) -> int:
    case_dir = REPO_ROOT / "cases" / case_name
    if not case_dir.exists():
......@@ -84,13 +189,7 @@ def _build_cdfg(case_name: str) -> int:
    if result.returncode != 0:
        return result.returncode

    output_dir = _resolve_cdfg_dir(case_name, config)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.parent.mkdir(parents=True, exist_ok=True)
......@@ -100,6 +199,98 @@ def _build_cdfg(case_name: str) -> int:
    return 0
def _view_cdfg(case_name: str) -> int:
    config = _load_case_config(case_name)
    cdfg_dir = _resolve_cdfg_dir(case_name, config)
    dot_path = _select_cdfg_dot(case_name, config, cdfg_dir)
    node_order, node_attrs, edges = _parse_dot_graph(dot_path)

    color_counts = {}
    for name in node_order:
        color = node_attrs.get(name, {}).get("color", "unknown")
        color_counts[color] = color_counts.get(color, 0) + 1
    edge_label_counts = {}
    for _, _, label in edges:
        edge_label_counts[label] = edge_label_counts.get(label, 0) + 1

    print("[view_cdfg] Summary")
    print(f"- dot: {dot_path}")
    print(f"- nodes: {len(node_order)}")
    print(f"- edges: {len(edges)}")
    print(f"- edge labels: {len(edge_label_counts)}")
    print("[view_cdfg] Node colors")
    for color, count in sorted(color_counts.items(), key=lambda x: (-x[1], x[0])):
        print(f"- {color}: {count}")
    print("[view_cdfg] Top edge labels")
    for label, count in sorted(edge_label_counts.items(), key=lambda x: (-x[1], x[0]))[:10]:
        print(f"- {label}: {count}")

    output_png = cdfg_dir / "view_cdfg.png"
    if _render_dot(dot_path, output_png):
        print(f"[view_cdfg] Figure saved to {output_png}")
    return 0
def _view_dataset(case_name: str) -> int:
    config = _load_case_config(case_name)
    dataset_cfg = config.get("dataset", {})
    output_dir_cfg = dataset_cfg.get("output_dir")
    if output_dir_cfg:
        dataset_dir = Path(output_dir_cfg)
        if not dataset_dir.is_absolute():
            dataset_dir = (REPO_ROOT / dataset_dir).resolve()
    else:
        dataset_dir = (REPO_ROOT / "results" / case_name / "dataset").resolve()
    graphs_npz = dataset_dir / "graphs.npz"
    meta_json = dataset_dir / "meta.json"

    print("[view_dataset] Summary")
    print(f"- dataset dir: {dataset_dir}")
    if meta_json.exists():
        with meta_json.open("r") as f:
            meta = json.load(f)
        print(f"- case: {meta.get('case')}")
        print(f"- cdfg_dot: {meta.get('cdfg_dot')}")
        print(f"- traces: {len(meta.get('vcd_files', []))}")
        print(f"- nodes (meta): {len(meta.get('node_names', []))}")
    else:
        print("- meta.json: not found")

    try:
        import numpy as np
    except Exception:
        print("[view_dataset] numpy not available; skipping npz parsing.")
    else:
        if graphs_npz.exists():
            with np.load(graphs_npz, allow_pickle=True) as data:
                designs = data["designs"].item()
            print(f"- designs: {len(designs)}")
            sample_name = next(iter(designs))
            sample = designs[sample_name]
            x = sample.get("x")
            edge_index = sample.get("edge_index")
            sim_res = sample.get("sim_res")
            print(f"- sample design: {sample_name}")
            print(f"- x shape: {getattr(x, 'shape', None)}")
            print(f"- edge_index shape: {getattr(edge_index, 'shape', None)}")
            print(f"- traces: {len(sim_res) if sim_res is not None else 0}")
            if sim_res is not None and len(sim_res) > 0:
                print(f"- trace[0] shape: {getattr(sim_res[0], 'shape', None)}")
        else:
            print(f"- graphs.npz: not found at {graphs_npz}")

    cdfg_dir = _resolve_cdfg_dir(case_name, config)
    dot_path = _select_cdfg_dot(case_name, config, cdfg_dir)
    graph_png = dataset_dir / "view_dataset_graph.png"
    if _render_dot(dot_path, graph_png):
        print(f"[view_dataset] Graph figure saved to {graph_png}")
    return 0
def main() -> int:
    parser = argparse.ArgumentParser(description="CDFG utilities")
    subparsers = parser.add_subparsers(dest="command", required=True)
......@@ -107,10 +298,29 @@ def main() -> int:
    build_parser = subparsers.add_parser("build_cdfg", help="Generate CDFG from RTL")
    build_parser.add_argument("--case", required=True, help="Case name under cases/")

    dataset_parser = subparsers.add_parser("build_dataset", help="Generate dataset from CDFG + waveforms")
    dataset_parser.add_argument("--case", required=True, help="Case name under cases/")

    view_cdfg_parser = subparsers.add_parser("view_cdfg", help="Summarize and render CDFG")
    view_cdfg_parser.add_argument("--case", required=True, help="Case name under cases/")

    view_dataset_parser = subparsers.add_parser("view_dataset", help="Summarize and render dataset")
    view_dataset_parser.add_argument("--case", required=True, help="Case name under cases/")

    args = parser.parse_args()
    if args.command == "build_cdfg":
        return _build_cdfg(args.case)
    if args.command == "build_dataset":
        dataset_builder = _import_dataset_builder()
        config = _load_case_config(args.case)
        output_npz = dataset_builder.build_dataset(args.case, REPO_ROOT, config)
        print(f"[build_dataset] Dataset saved to {output_npz}")
        return 0
    if args.command == "view_cdfg":
        return _view_cdfg(args.case)
    if args.command == "view_dataset":
        return _view_dataset(args.case)
    parser.print_help()
    return 1
......