Commit ce346a35 by lvzhengyang

impl. the train logic

parent d8d9890b
......@@ -38,6 +38,28 @@ Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
### Stage-1 Training (DynamicRTL-style)
Stage-1 training expects operator-type node features and bitvector waveforms. For `eth_fifo`, configure `[dataset]` in `cases/eth_fifo/config.toml`:
```
node_type_mode = "category"
edge_type_mode = "numeric"
sim_res_mode = "bitvector"
sim_res_width = 32
build_labels = true
op_to_index_path = "operator_types/op_to_index.json"
category_map_path = "operator_types/category_map.json"
signal_ops = ["Input", "Output", "Wire", "Reg"]
```
Rebuild the dataset and launch training:
```bash
./main.py build_dataset --case eth_fifo
python3 train_stage1.py --case eth_fifo --model default_shared
```
### View CDFG / Dataset
```bash
......
......@@ -29,3 +29,11 @@ clock_edge = "posedge"
# output_dir = "results/eth_fifo/dataset"
# Optional override for CDFG dot if needed.
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
node_type_mode = "category"
edge_type_mode = "numeric"
sim_res_mode = "bitvector"
sim_res_width = 32
build_labels = true
op_to_index_path = "operator_types/op_to_index.json"
category_map_path = "operator_types/category_map.json"
signal_ops = ["Input", "Output", "Wire", "Reg"]
{
"Add": "Add",
"BitAnd": "BitAnd",
"BitXor": "BitXor",
"Concat": "Concat",
"Cond": "Cond",
"Cond_If": "Cond_If",
"Const": "Const",
"Eq": "Eq",
"Input": "Input",
"Not": "Not",
"Output": "Output",
"PartSelect": "PartSelect",
"Reg": "Reg",
"Sub": "Sub",
"URand": "URand",
"URor": "URor",
"Wire": "Wire"
}
\ No newline at end of file
{
"Add": 0,
"BitAnd": 1,
"BitXor": 2,
"Concat": 3,
"Cond": 4,
"Cond_If": 5,
"Const": 6,
"Eq": 7,
"Input": 8,
"Not": 9,
"Output": 10,
"PartSelect": 11,
"Reg": 12,
"Sub": 13,
"URand": 14,
"URor": 15,
"Wire": 16
}
\ No newline at end of file
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
REPO_ROOT = Path(__file__).resolve().parent
MODEL_EXAMPLE_SRC = REPO_ROOT / "model_example" / "src"
def _ensure_op_to_index(data_dir: Path) -> None:
    """Make sure an operator-name -> index mapping is available for training.

    Resolution order:
      1. ``op_to_index`` embedded in ``<data_dir>/meta.json``;
      2. derived from ``node_names`` in meta.json (the op name is the text
         before the first comma), indexed in sorted order;
      3. an already-existing mapping file on disk (the repo ships
         ``operator_types/op_to_index.json``), left untouched.

    Side effects: may write the mapping to ``op_path`` and prints
    progress / warning messages. Never raises on a missing mapping.
    """
    meta_path = data_dir / "meta.json"
    op_to_index = None
    # Default target; meta.json may redirect it via op_to_index_path.
    op_path = REPO_ROOT / "operator_types" / "op_to_index.json"
    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        op_to_index = meta.get("op_to_index")
        op_path_cfg = meta.get("op_to_index_path")
        if op_path_cfg:
            cfg_path = Path(op_path_cfg)
            if not cfg_path.is_absolute():
                # Relative paths in meta.json are repo-root relative.
                cfg_path = (REPO_ROOT / cfg_path).resolve()
            op_path = cfg_path
        if op_to_index is None:
            # Derive the mapping from the dataset's node names; each name is
            # "<op>,<detail>" so the op is the prefix before the first comma.
            node_names = meta.get("node_names", [])
            ops = sorted({name.split(",")[0] for name in node_names if name})
            if ops:
                op_to_index = {op: idx for idx, op in enumerate(ops)}
    if op_to_index is None:
        # Fix: previously we warned and bailed out even when a curated
        # mapping file already existed on disk — load it instead so the
        # required-keys sanity check below still runs.
        if op_path.exists():
            try:
                op_to_index = json.loads(op_path.read_text())
                print(f"[train_stage1] Using existing op_to_index at {op_path}")
            except (OSError, json.JSONDecodeError):
                op_to_index = None
        if op_to_index is None:
            print("[train_stage1] Warning: op_to_index not found; training may fail.")
            return
    else:
        op_path.parent.mkdir(parents=True, exist_ok=True)
        op_path.write_text(json.dumps(op_to_index, indent=2))
        print(f"[train_stage1] op_to_index saved to {op_path}")
    # The model architectures special-case these operator types; warn early
    # if the mapping cannot serve them.
    required_keys = {"Input", "Const", "Wire", "Reg", "Cond", "Output"}
    missing = required_keys - set(op_to_index.keys())
    if missing:
        print(f"[train_stage1] Warning: op_to_index missing keys {sorted(missing)}. "
              "If you merged categories, update the model or mapping accordingly.")
def _ensure_labels(data_dir: Path, graph_npz: Path, label_npz: Path) -> None:
if label_npz.exists():
return
if not graph_npz.exists():
raise FileNotFoundError(f"graphs.npz not found: {graph_npz}")
with np.load(graph_npz, allow_pickle=True) as data:
designs = data["designs"].item()
labels = {name: {"y": 0} for name in designs}
label_npz.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(label_npz, labels=labels)
print(f"[train_stage1] labels.npz created at {label_npz}")
def _select_device(device_arg: str) -> torch.device:
if device_arg.startswith("cuda") and torch.cuda.is_available():
return torch.device(device_arg)
return torch.device("cpu")
def main() -> int:
    """CLI entry point: prepare dataset artifacts, then run stage-1 training.

    Builds the argument parser, materializes op_to_index / labels files if
    absent, loads the DynamicRTL-style model code from model_example/src,
    and hands everything to the Trainer. Returns 0 on success.
    """
    parser = argparse.ArgumentParser(description="Stage-1 training (DynamicRTL style)")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--distributed", action="store_true", help="If set, train in distributed mode")
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_step", default=50, type=int)
    parser.add_argument("--num_epochs", default=60, type=int)
    parser.add_argument("--num_rounds", default=20, type=int, help="Number of rounds to GNN propagate")
    parser.add_argument("--train_seq_len", default=50, type=int)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--device", default="cuda", type=str, help="cpu / cuda / cuda:0")
    parser.add_argument("--gpus", default="0", type=str, help="GPU IDs to use, example: 0,1,2,3")
    parser.add_argument("--model", default="default_shared", type=str, help="default_shared or default")
    parser.add_argument("--exp_id", default="stage1", type=str, help="Experiment ID")
    parser.add_argument("--supervision", default="default", type=str, help="only_branch/only_tgl/default")
    parser.add_argument("--split_with_design", action="store_true", help="Split dataset by design name")
    args = parser.parse_args()

    # Resolve dataset locations; the default lives under results/<case>/dataset.
    dataset_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graphs_path = dataset_dir / args.graph_npz_name
    labels_path = dataset_dir / args.label_npz_name

    # Downstream code resolves relative paths against the repo root.
    os.chdir(REPO_ROOT)
    _ensure_op_to_index(dataset_dir)
    _ensure_labels(dataset_dir, graphs_path, labels_path)

    # The DynamicRTL-style model sources are not a package; import by path.
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402
    from trainer import Trainer  # noqa: E402

    model_registry = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_registry:
        raise ValueError(f"Model not supported: {args.model}")
    if args.model == "default":
        print("[train_stage1] Warning: Model_default assumes fixed operator indices. Prefer --model default_shared.")

    run_device = _select_device(args.device)
    print(f"[train_stage1] Using device: {run_device}")

    parsed = NpzParser(str(dataset_dir), str(graphs_path), str(labels_path))
    train_split, val_split = parsed.get_dataset(split_with_design=args.split_with_design)

    net = model_registry[args.model](num_rounds=args.num_rounds)
    run_stamp = time.strftime("%Y-%m-%d-%H-%M")
    trainer = Trainer(
        args,
        net,
        distributed=args.distributed,
        batch_size=args.batch_size,
        device=run_device,
        gpus=args.gpus,
        training_id=run_stamp,
    )
    trainer.set_training_args(lr=args.lr, lr_step=args.lr_step)

    print("[train_stage1] Stage 1 Training ...")
    trainer.train(
        args.num_epochs,
        train_split,
        val_split,
        train_seq_len=args.train_seq_len,
        eval_seq_len=args.eval_seq_len,
        supervision=args.supervision,
    )
    print("[train_stage1] Finish Training")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    sys.exit(main())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment