Commit f4db720e by lvzhengyang

add dataset builder

parent da9db63a
......@@ -28,6 +28,25 @@ yosys = "yosys"
The generated CDFG is copied to `results/<case>/cdfg` by default (or the `output_dir` you set).
### Build Dataset (CDFG + Waveforms)
Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and run:
```bash
./main.py build_dataset --case eth_fifo
```
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
### View CDFG / Dataset
```bash
./main.py view_cdfg --case eth_fifo
./main.py view_dataset --case eth_fifo
```
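Both commands print a summary to the terminal. If Graphviz `dot` is installed, they also render the CDFG to a PNG: `view_cdfg` writes `view_cdfg.png` into the CDFG output directory (default `results/<case>/cdfg/`), and `view_dataset` writes `view_dataset_graph.png` into the dataset directory (default `results/<case>/dataset/`).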
---
## Legacy RTL2CDFG example
......
......@@ -10,3 +10,22 @@ divide = false
yosys = "yosys"
# Optional override for output location. Defaults to results/<case>/cdfg.
# output_dir = "results/eth_fifo/cdfg"
[waveform]
# Glob is relative to cases/<case>/.
vcd_glob = "sim/*/*/waveform.vcd"
# Signals are relative to the scope below.
scope = "tb_eth_fifo.u_dut"
clock = "clk"
reset = "reset"
reset_active = "high"
# Optional: clock edge on which signals are sampled, "posedge" (default) or "negedge".
clock_edge = "posedge"
# Optional: limit number of sampled cycles per trace.
# max_cycles = 200
[dataset]
# Optional override. Defaults to results/<case>/dataset.
# output_dir = "results/eth_fifo/dataset"
# Optional: path to the CDFG .dot file to use. Defaults to the first *.dot file (sorted by name) in the CDFG output directory.
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
#!/usr/bin/env python3
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
# Maps the ``color`` attribute of a node in the CDFG .dot file to an integer
# node-type code. build_graph() uses 0 for any node whose color is missing or
# not listed here.
COLOR_TO_TYPE = {
    "yellow": 1, # input
    "green": 2, # output
    "orange": 3, # sigset
    "black": 4, # wire
    "grey": 5, # const
    "red": 6, # label
    "pink": 7, # comb
}
def sanitize_name(name: str) -> str:
    r"""Return *name* with ``.``, ``[``, ``]`` and ``\`` each replaced by ``_``.

    The function is applied to VCD signal names (and to the configured clock
    and reset names) before they are compared with CDFG node names.

    The previous pattern was double-escaped inside a raw string, so it only
    matched the four-character sequence "one of ``.``/``\``/``[``, two
    backslashes, ``]``" and left ordinary hierarchical or indexed names
    untouched.
    """
    # Single character class: dot, opening bracket, closing bracket, backslash.
    return re.sub(r"[.\[\]\\]", "_", name)
@dataclass
class GraphData:
    """Graph extracted from a CDFG .dot file, as assembled by build_graph()."""
    # Node names in first-seen order; a node's position is its integer index.
    node_names: List[str]
    # Per-node type code looked up in COLOR_TO_TYPE (0 when the color is unknown).
    node_types: List[int]
    # int64 array built from two rows: source node indices, destination node indices.
    edge_index: np.ndarray
    # int64 array holding one label id per edge, in the same order as edge_index columns.
    edge_type: np.ndarray
    # Maps each distinct edge label string to the integer id stored in edge_type.
    edge_label_map: Dict[str, int]
def _parse_attr_map(attr_text: str) -> Dict[str, str]:
attrs: Dict[str, str] = {}
for part in attr_text.split(","):
part = part.strip()
if not part or "=" not in part:
continue
key, val = part.split("=", 1)
attrs[key.strip()] = val.strip().strip("\"")
return attrs
def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]], List[Tuple[str, str, str]]]:
    """Read a CDFG .dot file with a simple line-oriented parser.

    The parser expects one statement per line. Lines that are blank, start
    with ``digraph`` or consist of a single brace are ignored.

    Args:
        dot_path: Path of the .dot file to read.

    Returns:
        A tuple ``(node_order, node_attrs, edges)`` where ``node_order`` lists
        node names in first-seen order, ``node_attrs`` maps each node name to
        its attribute dict (empty when none were given), and ``edges`` is a
        list of ``(src, dst, label)`` tuples.

    Edge lines must look like ``src -> dst [label="..."]``; an edge line that
    does not match this shape is skipped and its endpoints are not registered.

    The previous edge pattern was double-escaped inside a raw string
    (``\\\\s``, ``\\\\[``), so it required literal backslashes and never matched
    a normal DOT edge.
    """
    node_attrs: Dict[str, Dict[str, str]] = {}
    node_order: List[str] = []
    edges: List[Tuple[str, str, str]] = []
    # Compiled once, outside the per-line loop.
    edge_re = re.compile(r'"?(.+?)"?\s*->\s*"?(.+?)"?\s*\[label="(.*?)"\]')

    def add_node(name: str, attrs: Optional[Dict[str, str]] = None) -> None:
        # First sighting fixes the node's position; later sightings only
        # merge in any additional attributes.
        if name not in node_attrs:
            node_order.append(name)
            node_attrs[name] = attrs or {}
        elif attrs:
            node_attrs[name].update(attrs)

    with dot_path.open("r") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("digraph") or line in {"{", "}"}:
                continue
            if "->" in line:
                # Edge line
                m = edge_re.search(line)
                if not m:
                    continue
                src = m.group(1).strip().strip('"')
                dst = m.group(2).strip().strip('"')
                add_node(src)
                add_node(dst)
                edges.append((src, dst, m.group(3)))
                continue
            # Node line
            if "[" in line and "]" in line:
                name_part, attr_part = line.split("[", 1)
                name = name_part.strip().strip(";").strip().strip('"')
                attr_text = attr_part.split("]", 1)[0]
                add_node(name, _parse_attr_map(attr_text))
            else:
                name = line.strip().strip(";").strip().strip('"')
                if name:
                    add_node(name)
    return node_order, node_attrs, edges
def build_graph(dot_path: Path) -> GraphData:
    """Turn a CDFG .dot file into index-based arrays wrapped in GraphData.

    Nodes are numbered by their first-seen order in the file, node types come
    from the ``color`` attribute via COLOR_TO_TYPE (0 when unknown), and each
    distinct edge label receives the next free integer id in order of first
    appearance.
    """
    names, attrs_by_node, raw_edges = parse_dot_graph(dot_path)

    type_codes: List[int] = [
        COLOR_TO_TYPE.get(attrs_by_node.get(node, {}).get("color"), 0)
        for node in names
    ]

    position = {node: pos for pos, node in enumerate(names)}
    label_ids: Dict[str, int] = {}
    sources: List[int] = []
    targets: List[int] = []
    kinds: List[int] = []
    for src, dst, label in raw_edges:
        # setdefault evaluates len() before inserting, so new labels get
        # consecutive ids starting at 0.
        label_id = label_ids.setdefault(label, len(label_ids))
        sources.append(position[src])
        targets.append(position[dst])
        kinds.append(label_id)

    return GraphData(
        names,
        type_codes,
        np.array([sources, targets], dtype=np.int64),
        np.array(kinds, dtype=np.int64),
        label_ids,
    )
def _parse_vcd_var_line(tokens: List[str]) -> Tuple[int, str, str]:
width = int(tokens[2])
var_id = tokens[3]
var_name = tokens[4]
if len(tokens) > 6 and tokens[5].startswith("["):
var_name = f"{var_name}{tokens[5]}"
return width, var_id, var_name
def _value_to_int(value: str) -> int:
if not value:
return -1
if any(ch in value for ch in "xXzZ"):
return -1
if value in {"0", "1"}:
return int(value, 2)
return int(value, 2)
def parse_vcd_trace(
    vcd_path: Path,
    node_names: List[str],
    clock: str,
    reset: Optional[str],
    reset_active: str,
    scope: Optional[str] = None,
    clock_edge: str = "posedge",
    max_cycles: Optional[int] = None,
) -> Tuple[np.ndarray, Dict[str, int], List[str]]:
    """Sample the values of graph nodes from a VCD file once per clock edge.

    The file is read line by line. ``$var`` declarations whose sanitized name
    matches an entry of ``node_names`` (or the clock/reset name) are tracked;
    every later value change updates the current value of the tracked signal.
    A clock transition of the requested kind marks a sample as pending, and
    the sample is taken when the next ``#`` timestamp line is reached (or at
    end of file), so it includes value changes listed after the clock change
    under the same timestamp.

    Args:
        vcd_path: VCD file to read.
        node_names: Node names to sample; one row of values is kept per entry.
        clock: Clock signal name (sanitized before matching).
        reset: Optional reset signal name (sanitized before matching).
        reset_active: ``"high"`` or ``"low"``; samples are dropped while the
            reset signal holds the active level or is unknown. Only applied
            when the reset signal was found in the VCD.
        scope: Optional dotted scope. When set, only variables declared in
            that scope or below are considered and names are made relative to
            it; otherwise the full dotted scope is prefixed to the name.
        clock_edge: ``"posedge"`` (0 -> 1) or ``"negedge"`` (1 -> 0). Any
            other string never triggers a sample.
        max_cycles: Optional upper bound on the number of recorded samples.

    Returns:
        ``(trace, signal_widths, mapped_nodes)``: ``trace`` is an int64 array
        built from one list of samples per entry of ``node_names`` (-1 marks
        an unknown or untracked value); ``signal_widths`` maps each tracked
        sanitized name (including clock/reset) to its declared width;
        ``mapped_nodes`` lists the entries of ``node_names`` found in the VCD.
    """
    node_set = set(node_names)
    clock_san = sanitize_name(clock)
    reset_san = sanitize_name(reset) if reset else None
    id_to_signal: Dict[str, str] = {}  # VCD identifier code -> sanitized signal name
    signal_widths: Dict[str, int] = {}
    current_values: Dict[str, int] = {}  # latest known value per tracked signal
    scope_stack: List[str] = []  # names of the $scope blocks currently open
    clock_id: Optional[str] = None
    reset_id: Optional[str] = None
    def use_signal(rel_name: str, width: int, var_id: str) -> None:
        # Start tracking a declared variable if it is a graph node, the clock
        # or the reset; everything else is ignored.
        nonlocal clock_id, reset_id
        name_san = sanitize_name(rel_name)
        if name_san in node_set or name_san in {clock_san, reset_san}:
            id_to_signal[var_id] = name_san
            # setdefault: the first declaration of a name wins for width and
            # initial (unknown) value.
            signal_widths.setdefault(name_san, width)
            current_values.setdefault(name_san, -1)
            if name_san == clock_san:
                clock_id = var_id
            if reset_san and name_san == reset_san:
                reset_id = var_id
    # NOTE(review): in_defs is assigned but never read in this function.
    in_defs = True
    in_dumpvars = False
    last_clock: Optional[int] = None  # previous known clock value
    pending_sample = False  # a qualifying clock edge was seen, sample not yet taken
    samples_collected = 0
    values_per_node = [[] for _ in node_names]
    def record_sample() -> None:
        # Append the current value of every node, unless the cycle limit is
        # reached or the design is in reset / reset is unknown.
        nonlocal samples_collected
        if max_cycles is not None and samples_collected >= max_cycles:
            return
        if reset_id is not None:
            reset_val = current_values.get(reset_san, -1)
            if reset_val == -1:
                return
            if reset_active == "high" and reset_val == 1:
                return
            if reset_active == "low" and reset_val == 0:
                return
        for idx, name in enumerate(node_names):
            values_per_node[idx].append(current_values.get(name, -1))
        samples_collected += 1
    with vcd_path.open("r") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith("$scope"):
                tokens = line.split()
                if len(tokens) >= 3:
                    # tokens[2] is the scope name ("$scope <type> <name> $end").
                    scope_stack.append(tokens[2])
                continue
            if line.startswith("$upscope"):
                if scope_stack:
                    scope_stack.pop()
                continue
            if line.startswith("$var"):
                tokens = line.split()
                if len(tokens) >= 6:
                    width, var_id, var_name = _parse_vcd_var_line(tokens)
                    full_scope = ".".join(scope_stack)
                    if scope:
                        # Only variables at or below the requested scope are
                        # considered; their names are made relative to it.
                        if full_scope == scope or full_scope.startswith(scope + "."):
                            rel_scope = full_scope[len(scope):].lstrip(".")
                            rel_name = f"{rel_scope}.{var_name}" if rel_scope else var_name
                            use_signal(rel_name, width, var_id)
                    else:
                        rel_name = f"{full_scope}.{var_name}" if full_scope else var_name
                        use_signal(rel_name, width, var_id)
                continue
            if line.startswith("$enddefinitions"):
                in_defs = False
                continue
            if line.startswith("$dumpvars"):
                in_dumpvars = True
                continue
            if in_dumpvars and line.startswith("$end"):
                in_dumpvars = False
                continue
            if line.startswith("#"):
                # New timestamp: all changes of the previous timestamp have
                # been applied, so a pending sample can be taken now.
                if pending_sample:
                    record_sample()
                    pending_sample = False
                    if max_cycles is not None and samples_collected >= max_cycles:
                        break
                continue
            # Value change
            if line[0] in "01xXzZ":
                # Scalar change: value character immediately followed by the id.
                value = line[0]
                var_id = line[1:].strip()
            elif line[0] in "bBrR":
                # Vector/real change: "<b|r><value> <id>".
                # NOTE(review): the value of an "r" line is also handed to
                # _value_to_int — confirm that real variables are handled as intended.
                parts = line.split()
                if len(parts) < 2:
                    continue
                value = parts[0][1:]
                var_id = parts[1]
            else:
                continue
            if var_id not in id_to_signal:
                continue
            signal_name = id_to_signal[var_id]
            value_int = _value_to_int(value)
            current_values[signal_name] = value_int
            if var_id == clock_id and value_int != -1:
                if last_clock is None:
                    # First known clock value only initializes the edge detector.
                    last_clock = value_int
                else:
                    is_posedge = last_clock == 0 and value_int == 1
                    is_negedge = last_clock == 1 and value_int == 0
                    if (clock_edge == "posedge" and is_posedge) or (clock_edge == "negedge" and is_negedge):
                        pending_sample = True
                    last_clock = value_int
    # Flush a sample that was still pending when the file ended.
    if pending_sample and (max_cycles is None or samples_collected < max_cycles):
        record_sample()
    trace = np.array(values_per_node, dtype=np.int64)
    mapped_nodes = [name for name in node_names if name in signal_widths]
    return trace, signal_widths, mapped_nodes
def _resolve_config_path(base: Path, value) -> Path:
    """Return *value* as a path: unchanged when absolute, else resolved under *base*."""
    path = Path(value)
    if path.is_absolute():
        return path
    return (base / path).resolve()


def build_dataset(
    case_name: str,
    repo_root: Path,
    config: dict,
) -> Path:
    """Build ``graphs.npz`` and ``meta.json`` for one case from its CDFG and VCD traces.

    Args:
        case_name: Case name; used for the default ``results/<case>/...`` and
            ``cases/<case>/...`` locations and as the design key in the npz.
        repo_root: Repository root against which relative paths are resolved.
        config: Parsed case configuration; the ``cdfg``, ``dataset`` and
            ``waveform`` tables are read.

    Returns:
        Path of the written ``graphs.npz``.

    Raises:
        KeyError: ``waveform.clock`` is missing, or neither
            ``waveform.vcd_files`` nor ``waveform.vcd_glob`` is set.
        ValueError: ``waveform.reset_active`` or ``waveform.clock_edge`` has
            an unsupported value.
        FileNotFoundError: The CDFG directory, a .dot file, or the VCD files
            cannot be found.

    The waveform settings are validated before any file is read, so a bad
    configuration is reported without first parsing the graph. ``clock_edge``
    is validated as well; an unsupported value previously produced traces
    with zero samples without any error.
    """
    cdfg_cfg = config.get("cdfg", {})
    dataset_cfg = config.get("dataset", {})
    wave_cfg = config.get("waveform", {})

    # --- Validate waveform settings up front (fail fast). ---
    clock = wave_cfg.get("clock")
    reset = wave_cfg.get("reset")
    reset_active = wave_cfg.get("reset_active", "high")
    scope = wave_cfg.get("scope")
    clock_edge = wave_cfg.get("clock_edge", "posedge")
    max_cycles = wave_cfg.get("max_cycles")
    if not clock:
        raise KeyError("Missing waveform.clock in config.toml")
    if reset_active not in {"high", "low"}:
        raise ValueError("waveform.reset_active must be 'high' or 'low'")
    if clock_edge not in {"posedge", "negedge"}:
        raise ValueError("waveform.clock_edge must be 'posedge' or 'negedge'")

    # --- Locate and parse the CDFG. ---
    cdfg_output_dir = cdfg_cfg.get("output_dir")
    if cdfg_output_dir:
        cdfg_dir = _resolve_config_path(repo_root, cdfg_output_dir)
    else:
        cdfg_dir = (repo_root / "results" / case_name / "cdfg").resolve()
    if not cdfg_dir.exists():
        raise FileNotFoundError(f"CDFG directory not found: {cdfg_dir}")

    dot_override = dataset_cfg.get("cdfg_dot") or cdfg_cfg.get("dot")
    if dot_override:
        dot_path = _resolve_config_path(repo_root, dot_override)
    else:
        # Without an override the alphabetically first .dot file is used.
        dot_files = sorted(cdfg_dir.glob("*.dot"))
        if not dot_files:
            raise FileNotFoundError(f"No .dot files found in {cdfg_dir}")
        dot_path = dot_files[0]
    graph = build_graph(dot_path)

    # --- Collect the VCD files; an explicit list takes precedence over the glob. ---
    case_dir = repo_root / "cases" / case_name
    vcd_glob = wave_cfg.get("vcd_glob")
    vcd_files = wave_cfg.get("vcd_files")
    if vcd_files and isinstance(vcd_files, list):
        vcd_paths = [_resolve_config_path(case_dir, vf) for vf in vcd_files]
    elif vcd_glob:
        vcd_paths = sorted(case_dir.glob(vcd_glob))
    else:
        raise KeyError("Missing waveform.vcd_glob or waveform.vcd_files in config.toml")
    if not vcd_paths:
        raise FileNotFoundError("No VCD files found for waveform parsing.")

    # --- Parse every trace. ---
    sim_res: List[np.ndarray] = []
    signal_widths: Dict[str, int] = {}
    mapped_nodes = set()
    trace_names: List[str] = []
    for vcd_path in vcd_paths:
        trace, widths, mapped = parse_vcd_trace(
            vcd_path=vcd_path,
            node_names=graph.node_names,
            clock=clock,
            reset=reset,
            reset_active=reset_active,
            scope=scope,
            clock_edge=clock_edge,
            max_cycles=max_cycles,
        )
        sim_res.append(trace)
        signal_widths.update(widths)
        mapped_nodes.update(mapped)
        trace_names.append(str(vcd_path))

    # --- Assemble node features: one row of [node_id, node_type, node_width]. ---
    has_sim_res = np.array(
        [1 if name in mapped_nodes else 0 for name in graph.node_names], dtype=np.int64
    )
    # Nodes never seen in any VCD default to width 1.
    node_widths = np.array(
        [signal_widths.get(name, 1) for name in graph.node_names], dtype=np.int64
    )
    node_ids = np.arange(len(graph.node_names), dtype=np.int64)
    node_types = np.array(graph.node_types, dtype=np.int64)
    x = np.stack([node_ids, node_types, node_widths], axis=1)

    designs = {
        case_name: {
            "x": x,
            "edge_index": graph.edge_index,
            "edge_type": graph.edge_type,
            "sim_res": sim_res,
            "has_sim_res": has_sim_res,
            # No power/timing/area labels are produced by this builder.
            "power": None,
            "slack": None,
            "area": None,
        }
    }

    # --- Write outputs. ---
    output_dir_cfg = dataset_cfg.get("output_dir")
    if output_dir_cfg:
        output_dir = _resolve_config_path(repo_root, output_dir_cfg)
    else:
        output_dir = (repo_root / "results" / case_name / "dataset").resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    output_npz = output_dir / "graphs.npz"
    # The dict is stored as a pickled object array; readers must load it with
    # allow_pickle=True and call .item().
    np.savez_compressed(output_npz, designs=designs)

    meta = {
        "case": case_name,
        "cdfg_dot": str(dot_path),
        "vcd_files": trace_names,
        "node_names": graph.node_names,
        "edge_label_map": graph.edge_label_map,
        "clock": clock,
        "reset": reset,
        "reset_active": reset_active,
        "scope": scope,
        "clock_edge": clock_edge,
        "max_cycles": max_cycles,
    }
    with (output_dir / "meta.json").open("w") as f:
        json.dump(meta, f, indent=2)
    return output_npz
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional, Callable, List
import os.path as osp
import numpy as np
import torch
import shutil
import os
import copy
import random
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from utils.data_utils import read_npz_file
from parser_func import *
class NpzParser():
    '''
    Parse the npz file into an inmemory torch_geometric.data.Data object
    '''
    def __init__(self, data_dir, circuit_path, label_path, \
                 random_shuffle=False, trainval_split=0.9):
        """Build (or load) the in-memory dataset and store the split settings.

        Args:
            data_dir: Root directory handed to the in-memory dataset.
            circuit_path: Path of the graphs npz file.
            label_path: Path of the labels npz file.
            random_shuffle: Whether get_dataset() shuffles before splitting.
            trainval_split: Fraction of the data (or of the designs) that
                goes into the training split.
        """
        self.data_dir = data_dir
        self.dataset = self.inmemory_dataset(data_dir, circuit_path, label_path)
        self.trainval_split = trainval_split
        self.random_shuffle = random_shuffle

    def get_dataset(self, split_with_design=False):
        """Return ``(train_dataset, val_dataset)``.

        Args:
            split_with_design: When False, the first ``trainval_split``
                fraction of the (optionally shuffled) graphs is used for
                training. When True, whole designs are assigned to one side:
                the design name is the graph name up to ``'_trace'`` and the
                selection of training designs uses a fixed seed.

        Note:
            When ``random_shuffle`` is set, ``self.dataset`` is replaced by
            its shuffled version as a side effect.
        """
        data_len = len(self.dataset)
        if not split_with_design:
            if self.random_shuffle:
                perm = torch.randperm(len(self.dataset))
                self.dataset = self.dataset[perm]
            training_cutoff = int(data_len * self.trainval_split)
            train_dataset = self.dataset[:training_cutoff]
            val_dataset = self.dataset[training_cutoff:]
        else:
            module_name_list = sorted({
                self.dataset[i].name.split('_trace')[0]
                for i in range(len(self.dataset))
            })
            random.seed(666) # the same shuffle each time
            random.shuffle(module_name_list)
            n_train_modules = int(len(module_name_list) * self.trainval_split)
            # A set gives O(1) membership tests in the loop below.
            select_modules = set(module_name_list[:n_train_modules])

            if self.random_shuffle:
                perm = torch.randperm(len(self.dataset))
                self.dataset = self.dataset[perm]

            # dtype=bool keeps `~` valid even when the dataset is empty.
            sel_mask = np.array(
                [
                    self.dataset[i].name.split('_trace')[0] in select_modules
                    for i in range(len(self.dataset))
                ],
                dtype=bool,
            )
            train_dataset = self.dataset[sel_mask]
            val_dataset = self.dataset[~sel_mask]

        return train_dataset, val_dataset

    class inmemory_dataset(InMemoryDataset):
        """InMemoryDataset that is processed once from a graphs npz and a labels npz."""

        def __init__(self, root, graph_path, label_path, transform=None, pre_transform=None, pre_filter=None):
            self.name = 'npz_inmm_dataset'
            self.root = root
            self.graph_path = graph_path
            self.label_path = label_path
            super().__init__(root, transform, pre_transform, pre_filter)
            self.data, self.slices = torch.load(self.processed_paths[0])

        @property
        def raw_dir(self):
            return self.root

        @property
        def processed_dir(self):
            # One cache directory per graphs file, named after its basename.
            name = 'inmemory-' + os.path.basename(self.graph_path).split('.')[0]
            return osp.join(self.root, name)

        @property
        def raw_file_names(self) -> List[str]:
            return [self.graph_path, self.label_path]

        @property
        def processed_file_names(self) -> List[str]:
            return ['data.pt']

        def download(self):
            pass

        def process(self):
            """Convert every usable design/trace pair into a graph and save the collated result.

            Designs with missing power, slack or area, or with a slack, area
            or first power value whose magnitude is below 1e-5, are skipped.
            """
            data_list = []
            designs = read_npz_file(self.graph_path)['designs'].item()
            labels = read_npz_file(self.label_path)['labels'].item()

            for design_idx, design_name in enumerate(designs):
                print('Parse design: {}, {:} / {:} = {:.2f}%'.format(design_name, design_idx, len(designs), design_idx / len(designs) * 100))
                design = designs[design_name]
                x = design['x']  # x[i][0], node_id; x[i][1], node_type; x[i][2], node_width
                edge_index = design['edge_index']
                edge_type = design['edge_type']
                sim_res = design['sim_res']
                has_sim_res = design['has_sim_res']
                power = design['power']
                slack = design['slack']
                area = design['area']

                # Identity checks: `== None` on a numpy array is evaluated
                # element-wise and cannot be used in a boolean `or` chain.
                if power is None or slack is None or area is None:
                    continue
                if abs(slack) < 1e-5 or abs(area) < 1e-5 or abs(power[0]) < 1e-5:
                    continue

                y = labels[design_name]['y']

                # added by XXXX-5 2024/09/15
                # using each simulation result as a separate data
                for trace_id, trace in enumerate(sim_res):
                    graph = parse_pyg_mlpgate(
                        x=x, edge_index=edge_index, edge_type=edge_type, sim_res=trace, has_sim_res=has_sim_res, y=y,
                        power=power[trace_id], slack=slack, area=area
                    )
                    graph.name = design_name + '_trace' + str(trace_id)
                    data_list.append(graph)

            data, slices = self.collate(data_list)
            torch.save((data, slices), self.processed_paths[0])
            print('[INFO] Inmemory dataset save: ', self.processed_paths[0])
            print('Total Designs: {:}'.format(len(data_list)))

        def __repr__(self) -> str:
            return f'{self.name}({len(self)})'
\ No newline at end of file
#!/usr/bin/env python3
import argparse
import json
import re
import shutil
import subprocess
import sys
......@@ -22,6 +24,69 @@ def _load_case_config(case_name: str) -> dict:
return tomllib.load(f)
def _import_dataset_builder():
try:
import dataset_builder
except Exception as exc:
raise RuntimeError(
"Failed to import dataset_builder. Ensure numpy is installed for dataset operations."
) from exc
return dataset_builder
def _parse_dot_graph(dot_path: Path):
node_attrs = {}
node_order = []
edges = []
def add_node(name: str, attrs=None) -> None:
if name not in node_attrs:
node_order.append(name)
node_attrs[name] = attrs or {}
elif attrs:
node_attrs[name].update(attrs)
def parse_attr_map(attr_text: str):
attrs = {}
for part in attr_text.split(","):
part = part.strip()
if not part or "=" not in part:
continue
key, val = part.split("=", 1)
attrs[key.strip()] = val.strip().strip("\"")
return attrs
with dot_path.open("r") as f:
for raw in f:
line = raw.strip()
if not line or line.startswith("digraph") or line in {"{", "}"}:
continue
if "->" in line:
m = re.search(r'\"?(.+?)\"?\\s*->\\s*\"?(.+?)\"?\\s*\\[label=\"(.*?)\"\\]', line)
if not m:
continue
src, dst, label = m.group(1), m.group(2), m.group(3)
src = src.strip().strip("\"")
dst = dst.strip().strip("\"")
add_node(src)
add_node(dst)
edges.append((src, dst, label))
continue
if "[" in line and "]" in line:
name_part, attr_part = line.split("[", 1)
name = name_part.strip().strip(";").strip().strip("\"")
attr_text = attr_part.split("]", 1)[0]
attrs = parse_attr_map(attr_text)
add_node(name, attrs)
else:
name = line.strip().strip(";").strip().strip("\"")
if name:
add_node(name)
return node_order, node_attrs, edges
def _resolve_rtl_files(case_dir: Path, rtl_files: list[str]) -> list[str]:
resolved = []
missing = []
......@@ -40,6 +105,46 @@ def _resolve_rtl_files(case_dir: Path, rtl_files: list[str]) -> list[str]:
return resolved
def _resolve_cdfg_dir(case_name: str, config: dict) -> Path:
cdfg_cfg = config.get("cdfg", {})
output_dir_cfg = cdfg_cfg.get("output_dir")
if output_dir_cfg:
output_dir = Path(output_dir_cfg)
if not output_dir.is_absolute():
output_dir = (REPO_ROOT / output_dir).resolve()
else:
output_dir = (REPO_ROOT / "results" / case_name / "cdfg").resolve()
return output_dir
def _select_cdfg_dot(case_name: str, config: dict, cdfg_dir: Path) -> Path:
dataset_cfg = config.get("dataset", {})
cdfg_cfg = config.get("cdfg", {})
dot_override = dataset_cfg.get("cdfg_dot") or cdfg_cfg.get("dot")
if dot_override:
dot_path = Path(dot_override)
if not dot_path.is_absolute():
dot_path = (REPO_ROOT / dot_path).resolve()
return dot_path
dot_files = sorted(cdfg_dir.glob("*.dot"))
if not dot_files:
raise FileNotFoundError(f"No .dot files found in {cdfg_dir}")
return dot_files[0]
def _render_dot(dot_path: Path, output_png: Path) -> bool:
    """Render ``dot_path`` to ``output_png`` with the Graphviz ``dot`` executable.

    Returns True when the PNG was rendered. When ``dot`` is not on PATH or
    exits with a non-zero status, a message is printed and False is returned.
    """
    executable = shutil.which("dot")
    if not executable:
        print("[view] Graphviz 'dot' not found. Skipping PNG render.")
        return False

    output_png.parent.mkdir(parents=True, exist_ok=True)
    command = [executable, "-Tpng", str(dot_path), "-o", str(output_png)]
    if subprocess.run(command).returncode == 0:
        return True

    print("[view] dot failed to render PNG.")
    return False
def _build_cdfg(case_name: str) -> int:
case_dir = REPO_ROOT / "cases" / case_name
if not case_dir.exists():
......@@ -84,13 +189,7 @@ def _build_cdfg(case_name: str) -> int:
if result.returncode != 0:
return result.returncode
output_dir_cfg = cdfg_cfg.get("output_dir")
if output_dir_cfg:
output_dir = Path(output_dir_cfg)
if not output_dir.is_absolute():
output_dir = (REPO_ROOT / output_dir).resolve()
else:
output_dir = (REPO_ROOT / "results" / case_name / "cdfg").resolve()
output_dir = _resolve_cdfg_dir(case_name, config)
if output_dir.exists():
shutil.rmtree(output_dir)
output_dir.parent.mkdir(parents=True, exist_ok=True)
......@@ -100,6 +199,98 @@ def _build_cdfg(case_name: str) -> int:
return 0
def _view_cdfg(case_name: str) -> int:
    """Print a summary of a case's CDFG and try to render it as a PNG.

    The summary lists node and edge totals, the node count per color, and
    the ten most frequent edge labels. The figure is written to
    ``view_cdfg.png`` inside the CDFG directory when rendering succeeds.
    Always returns 0.
    """
    def tally(keys):
        # Count how often each key occurs.
        totals = {}
        for key in keys:
            totals[key] = totals.get(key, 0) + 1
        return totals

    def most_frequent_first(item):
        # Sort by descending count, ties broken by name.
        return (-item[1], item[0])

    cfg = _load_case_config(case_name)
    cdfg_dir = _resolve_cdfg_dir(case_name, cfg)
    dot_path = _select_cdfg_dot(case_name, cfg, cdfg_dir)
    names, attrs_by_node, edge_list = _parse_dot_graph(dot_path)

    per_color = tally(
        attrs_by_node.get(node, {}).get("color", "unknown") for node in names
    )
    per_label = tally(edge[2] for edge in edge_list)

    print("[view_cdfg] Summary")
    print(f"- dot: {dot_path}")
    print(f"- nodes: {len(names)}")
    print(f"- edges: {len(edge_list)}")
    print(f"- edge labels: {len(per_label)}")

    print("[view_cdfg] Node colors")
    for color, count in sorted(per_color.items(), key=most_frequent_first):
        print(f"- {color}: {count}")

    print("[view_cdfg] Top edge labels")
    top_labels = sorted(per_label.items(), key=most_frequent_first)[:10]
    for label, count in top_labels:
        print(f"- {label}: {count}")

    output_png = cdfg_dir / "view_cdfg.png"
    rendered = _render_dot(dot_path, output_png)
    if rendered:
        print(f"[view_cdfg] Figure saved to {output_png}")
    return 0
def _view_dataset(case_name: str) -> int:
    """Print a summary of a case's dataset and try to render its CDFG as a PNG.

    Reads ``meta.json`` and ``graphs.npz`` from the dataset directory
    (``dataset.output_dir`` or ``results/<case>/dataset`` under REPO_ROOT),
    prints what is available, then renders the case's CDFG .dot file to
    ``view_dataset_graph.png`` in the dataset directory. Always returns 0.

    An npz file that contains no designs is reported as such instead of
    raising ``StopIteration``, and the trace count is taken with ``len()`` so
    that an array-valued ``sim_res`` is not evaluated for truthiness.
    """
    config = _load_case_config(case_name)
    dataset_cfg = config.get("dataset", {})
    output_dir_cfg = dataset_cfg.get("output_dir")
    if output_dir_cfg:
        dataset_dir = Path(output_dir_cfg)
        if not dataset_dir.is_absolute():
            dataset_dir = (REPO_ROOT / dataset_dir).resolve()
    else:
        dataset_dir = (REPO_ROOT / "results" / case_name / "dataset").resolve()

    graphs_npz = dataset_dir / "graphs.npz"
    meta_json = dataset_dir / "meta.json"

    print("[view_dataset] Summary")
    print(f"- dataset dir: {dataset_dir}")
    if meta_json.exists():
        with meta_json.open("r") as f:
            meta = json.load(f)
        print(f"- case: {meta.get('case')}")
        print(f"- cdfg_dot: {meta.get('cdfg_dot')}")
        print(f"- traces: {len(meta.get('vcd_files', []))}")
        print(f"- nodes (meta): {len(meta.get('node_names', []))}")
    else:
        print("- meta.json: not found")

    # numpy is optional for this command: without it only meta.json is shown.
    try:
        import numpy as np  # noqa: F401
    except Exception:
        print("[view_dataset] numpy not available; skipping npz parsing.")
    else:
        if graphs_npz.exists():
            # SECURITY: allow_pickle=True unpickles the stored dict and can run
            # arbitrary code; only open dataset files from a trusted source.
            with np.load(graphs_npz, allow_pickle=True) as data:
                designs = data["designs"].item()
            print(f"- designs: {len(designs)}")
            if designs:
                sample_name, sample = next(iter(designs.items()))
                x = sample.get("x")
                edge_index = sample.get("edge_index")
                sim_res = sample.get("sim_res")
                trace_count = len(sim_res) if sim_res is not None else 0
                print(f"- sample design: {sample_name}")
                print(f"- x shape: {getattr(x, 'shape', None)}")
                print(f"- edge_index shape: {getattr(edge_index, 'shape', None)}")
                print(f"- traces: {trace_count}")
                if trace_count:
                    print(f"- trace[0] shape: {getattr(sim_res[0], 'shape', None)}")
        else:
            print(f"- graphs.npz: not found at {graphs_npz}")

    cdfg_dir = _resolve_cdfg_dir(case_name, config)
    dot_path = _select_cdfg_dot(case_name, config, cdfg_dir)
    graph_png = dataset_dir / "view_dataset_graph.png"
    if _render_dot(dot_path, graph_png):
        print(f"[view_dataset] Graph figure saved to {graph_png}")
    return 0
def main() -> int:
parser = argparse.ArgumentParser(description="CDFG utilities")
subparsers = parser.add_subparsers(dest="command", required=True)
......@@ -107,10 +298,29 @@ def main() -> int:
build_parser = subparsers.add_parser("build_cdfg", help="Generate CDFG from RTL")
build_parser.add_argument("--case", required=True, help="Case name under cases/")
dataset_parser = subparsers.add_parser("build_dataset", help="Generate dataset from CDFG + waveforms")
dataset_parser.add_argument("--case", required=True, help="Case name under cases/")
view_cdfg_parser = subparsers.add_parser("view_cdfg", help="Summarize and render CDFG")
view_cdfg_parser.add_argument("--case", required=True, help="Case name under cases/")
view_dataset_parser = subparsers.add_parser("view_dataset", help="Summarize and render dataset")
view_dataset_parser.add_argument("--case", required=True, help="Case name under cases/")
args = parser.parse_args()
if args.command == "build_cdfg":
return _build_cdfg(args.case)
if args.command == "build_dataset":
dataset_builder = _import_dataset_builder()
config = _load_case_config(args.case)
output_npz = dataset_builder.build_dataset(args.case, REPO_ROOT, config)
print(f"[build_dataset] Dataset saved to {output_npz}")
return 0
if args.command == "view_cdfg":
return _view_cdfg(args.case)
if args.command == "view_dataset":
return _view_dataset(args.case)
parser.print_help()
return 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment