Commit 165224f9 by lvzhengyang

add property alignment training

parent ce346a35
# Property-Alignment Embedding Plan
## 1. Idea Summary (Refined)
We replace the original Stage-1 objectives (branch-hit, toggle rate) with a **property-alignment objective**.
Each property is a Boolean expression (e.g., `~a`, `a|b`, `~a & b`) built from **variables + operators**.
We embed each property into a vector space where:
- **Implication ⇒ parallel**: If property P implies Q, their embeddings `p` and `q` point in (nearly) the same direction.
- **Specificity ⇒ shorter**: If P implies Q, then `||p|| < ||q||` (the more specific property gets the shorter vector).
- **Conflict ⇒ orthogonal**: If P and Q can never hold at the same time, then `p · q ≈ 0`.
This embedding space should encode logical structure and dynamic behavior.
---
## 2. Property Embedding Space Definition
Let each property embedding be a vector `p ∈ R^d`.
Constraints:
- **Implication direction**:
  - `L_dir = 1 - cos(p, q)`
- **Implication length ordering**:
  - `L_len = max(0, ||p|| - ||q|| + margin)`
- **Conflict orthogonality**:
  - `L_conf = (p · q)^2`

Optional grounding:

- Predict truth probability from the property embedding:
  - `L_truth = BCE(sigmoid(w^T p), truth)`
Total loss:
```
L = w1 * L_dir + w2 * L_len + w3 * L_conf + w4 * L_truth
```
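As a concrete illustration, the three geometric terms can be computed on toy embeddings. This is a minimal sketch; the vectors below are invented so that `p` implies `q` and `r` conflicts with `p`, they are not taken from the codebase:

```python
import torch
import torch.nn.functional as F

# Invented toy embeddings: p should imply q, and r should conflict with p.
p = torch.tensor([1.0, 0.0])   # specific property (short vector)
q = torch.tensor([2.0, 0.1])   # implied, more general property (longer vector)
r = torch.tensor([0.0, 1.5])   # property that conflicts with p

margin = 0.1
loss_dir = 1.0 - F.cosine_similarity(p, q, dim=0)   # -> 0 when p, q are parallel
loss_len = F.relu(p.norm() - q.norm() + margin)     # -> 0 when ||p|| + margin <= ||q||
loss_conf = (p @ r) ** 2                            # -> 0 when p, r are orthogonal
total = 1.0 * loss_dir + 1.0 * loss_len + 1.0 * loss_conf
```

With these vectors the length and conflict terms are exactly zero and only a small direction penalty remains, matching the intended geometry.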
---
## 3. Mapping Variables + Operators to Property Embeddings
### 3.1 Variable embeddings
Use the existing node encoder (graph + waveform) to produce node embeddings:
- `e_v = GNN(node_v)`
### 3.2 Operator composition
Define small learnable composition functions per operator:
- `NOT(x) = f_not(x)`
- `AND(x,y) = f_and([x;y])`
- `OR(x,y) = f_or([x;y])`
Each property is a tree:
```
prop_emb = ComposeTree(operators, variable_embeddings)
```
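A minimal sketch of how such a `ComposeTree` recursion could look, assuming each operator is a small learnable module. The nested-tuple tree format and the `f_not`/`f_and` names are illustrative stand-ins, not the repo's actual API (the real encoder appears in `property_encoder.py`):

```python
import torch
from torch import nn

dim = 8
f_not = nn.Linear(dim, dim)        # stand-in for the learned NOT composition
f_and = nn.Linear(2 * dim, dim)    # stand-in for the learned AND composition

def compose(expr, node_emb):
    """Fold a property tree into one vector.
    expr is an int (variable index) or a tuple (op, *children)."""
    if isinstance(expr, int):
        return node_emb[expr]
    op, *children = expr
    if op == "NOT":
        return f_not(compose(children[0], node_emb))
    if op == "AND":
        left, right = (compose(c, node_emb) for c in children)
        return f_and(torch.cat([left, right], dim=-1))
    raise ValueError(f"Unsupported op: {op}")

node_emb = torch.randn(4, dim)                        # embeddings for 4 variables
prop_emb = compose(("AND", ("NOT", 0), 1), node_emb)  # ~v0 & v1
```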
---
## 4. How to Get Training Constraints
From waveform traces:
1. Evaluate each property per time-step (truth vector).
2. Derive relations:
   - **Implication**: `P ⇒ Q` if `truth(P) ≤ truth(Q)` across all steps.
   - **Conflict**: `P ⟂ Q` if `(truth(P) & truth(Q)) == 0` across all steps.
This yields automatic pairs for `L_dir`, `L_len`, and `L_conf`.
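For instance, with `P = a & b` and `Q = a`, the implication check holds at every time-step by construction, and `P` can never be true together with `~a`. A small sketch of both checks on toy truth vectors (the vectors are invented for illustration):

```python
import torch

# Toy per-time-step truth vectors for variables a and b (6 steps).
a = torch.tensor([1, 0, 1, 1, 0, 0], dtype=torch.bool)
b = torch.tensor([1, 1, 0, 1, 0, 1], dtype=torch.bool)

P = a & b      # more specific property
Q = a          # more general property

implication = bool(torch.all(~P | Q))    # truth(P) <= truth(Q) at every step
conflict = bool(torch.all(~(P & ~a)))    # P and ~a are never true together
```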
---
## 5. Implementation Plan (Code)
### 5.1 New modules
- `property_expr.py`
  - Represent property trees
  - Evaluate properties on waveform traces
  - Enumerate properties (depth-1, depth-2)
- `property_encoder.py`
  - Operator embeddings / composition functions
  - Produce property embeddings from node embeddings
- `property_trainer.py`
  - Build implication/conflict pairs
  - Compute property losses
### 5.2 Training pipeline changes
- Keep current node encoder (Graph + seq2seq + GNN)
- Replace stage-1 losses with property-alignment losses
- Optionally freeze encoder at start, then finetune
---
## 6. Experiment Plan
### Experiment A: Shallow properties
- Operators: `NOT`, `AND`, `OR`
- Properties: `~a`, `a|b`, `a&b`, `~a & b`
- Metrics:
  - Cosine similarity for implication pairs (↑)
  - Dot product for conflict pairs (→ 0)
  - Length-ordering accuracy (↑)
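These three metrics can be computed directly from the embedding matrix and the mined pairs. A sketch, where `property_metrics` is a hypothetical helper (not an existing function in the repo) and the random embeddings stand in for trained ones:

```python
import torch
import torch.nn.functional as F

def property_metrics(emb, imp_pairs, conf_pairs):
    """Experiment-A metrics: mean implication cosine (want high),
    length-ordering accuracy (want high), mean |dot| on conflict pairs (want ~0)."""
    i = torch.tensor([a for a, _ in imp_pairs])
    j = torch.tensor([b for _, b in imp_pairs])
    cos = F.cosine_similarity(emb[i], emb[j], dim=1).mean().item()
    len_acc = (emb[i].norm(dim=1) < emb[j].norm(dim=1)).float().mean().item()
    k = torch.tensor([a for a, _ in conf_pairs])
    m = torch.tensor([b for _, b in conf_pairs])
    conf_dot = (emb[k] * emb[m]).sum(dim=1).abs().mean().item()
    return {"imp_cos": cos, "len_acc": len_acc, "conf_dot": conf_dot}

emb = torch.randn(6, 16)   # stand-in for trained property embeddings
metrics = property_metrics(emb, imp_pairs=[(0, 1), (2, 3)], conf_pairs=[(4, 5)])
```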
### Experiment B: Depth-2 properties
- Properties with nesting, e.g. `(a|b)&c`, `~(a&b)`
- Check if geometry still holds
### Experiment C: Downstream utility
Use learned property embeddings to predict:
- Assertions
- Power / area / slack
Compare against the original Stage-1 baseline (branch-hit + toggle objectives).
---
## 7. Risks & Mitigations
- **Degenerate embeddings** (all vectors zero or pointing the same direction):
  - Add `L_truth` grounding
  - Normalize only in the cosine term, not globally
- **Operator category mapping mismatch**:
  - Keep the category map aligned with the operator set
  - Check required ops (`Input`, `Const`, `Cond`, `Output`)
- **Data sparsity for implication pairs**:
  - Balance pair sampling
  - Use shallow formulas first
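A minimal sketch of the balanced pair sampling mitigation, subsampling each relation type to a common budget so that neither loss term dominates (the helper name and equal-count policy are illustrative choices, not code from the repo):

```python
import random

def balance_pairs(imp_pairs, conf_pairs, max_pairs=512, seed=0):
    """Subsample implication and conflict pairs to the same budget."""
    rng = random.Random(seed)
    budget = min(max_pairs, len(imp_pairs), len(conf_pairs))
    return rng.sample(imp_pairs, budget), rng.sample(conf_pairs, budget)

# 100 implication pairs vs. 40 conflict pairs -> both capped at 40.
imp, conf = balance_pairs([(0, 1)] * 100, [(2, 3)] * 40, max_pairs=50)
```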
---
## 8. Milestones
1. Property generation + evaluation on waveforms
2. Property embedding + loss implementation
3. Train on shallow formulas (Experiment A)
4. Extend to depth-2 (Experiment B)
5. Downstream validation (Experiment C)
...@@ -60,6 +60,22 @@ Rebuild the dataset and launch training:
python3 train_stage1.py --case eth_fifo --model default_shared
```
### Property-Alignment Training
Generate the dataset and run property training:
```bash
./main.py build_dataset --case eth_fifo
./main.py train_property --case eth_fifo -- --model default_shared
```
Evaluate or test a trained property encoder:
```bash
python3 eval_property.py --case eth_fifo --checkpoint results/eth_fifo/property/property/property_encoder_last.pth
python3 test_property.py --case eth_fifo --checkpoint results/eth_fifo/property/property/property_encoder_last.pth
```
### View CDFG / Dataset
```bash
...
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

import torch
from torch_geometric.loader import DataLoader

from property_expr import generate_properties
from property_encoder import PropertyEncoder
from property_trainer import PropertyTrainer, PropertyLossConfig, PropertyCache

REPO_ROOT = Path(__file__).resolve().parent
MODEL_EXAMPLE_SRC = REPO_ROOT / "model_example" / "src"


def _select_device(device_arg: str) -> torch.device:
    if device_arg.startswith("cuda") and torch.cuda.is_available():
        return torch.device(device_arg)
    return torch.device("cpu")


def _load_meta(data_dir: Path) -> dict:
    meta_path = data_dir / "meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"meta.json not found in {data_dir}")
    return json.loads(meta_path.read_text())


def _get_var_indices(meta: dict, signal_ops: list[str]) -> list[int]:
    node_names = meta.get("node_names", [])
    var_indices = []
    for idx, name in enumerate(node_names):
        op = name.split(",")[0] if name else ""
        if op in signal_ops:
            var_indices.append(idx)
    return var_indices


def main() -> int:
    parser = argparse.ArgumentParser(description="Property-alignment evaluation")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--checkpoint", required=True, help="Property encoder checkpoint path")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--num_rounds", default=20, type=int)
    parser.add_argument("--model", default="default_shared", type=str)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--max_props", default=64, type=int)
    parser.add_argument("--max_pairs", default=512, type=int)
    parser.add_argument("--signal_ops", default="Input,Output,Wire,Reg", type=str)
    parser.add_argument("--include_ops", default="VAR,NOT,AND,OR,XOR,IMPLIES,EQUIV", type=str)
    parser.add_argument("--cache_dir", default="", type=str)
    parser.add_argument("--no_cache", action="store_true")
    args = parser.parse_args()

    data_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graph_npz = data_dir / args.graph_npz_name
    label_npz = data_dir / args.label_npz_name

    os.chdir(REPO_ROOT)
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402

    model_factory = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_factory:
        raise ValueError(f"Model not supported: {args.model}")

    device = _select_device(args.device)
    meta = _load_meta(data_dir)
    signal_ops = [s.strip() for s in args.signal_ops.split(",") if s.strip()]
    var_indices = _get_var_indices(meta, signal_ops)
    include_ops = tuple(s.strip() for s in args.include_ops.split(",") if s.strip())
    props = generate_properties(var_indices, max_props=args.max_props, include_ops=include_ops)

    dataset = NpzParser(str(data_dir), str(graph_npz), str(label_npz))
    _, val_dataset = dataset.get_dataset(split_with_design=True)
    if args.batch_size != 1:
        raise ValueError("Property evaluation assumes batch_size=1 for correct indexing.")
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

    node_encoder = model_factory[args.model](num_rounds=args.num_rounds)
    prop_encoder = PropertyEncoder(node_dim=node_encoder.dim_hidden)
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    prop_encoder.load_state_dict(checkpoint["state_dict"])
    trainer = PropertyTrainer(node_encoder=node_encoder, prop_encoder=prop_encoder, device=device, freeze_node_encoder=True)
    loss_cfg = PropertyLossConfig()

    cache = None
    if not args.no_cache:
        cache_dir = Path(args.cache_dir) if args.cache_dir else (REPO_ROOT / "results" / args.case / "property" / "eval_cache")
        cache = PropertyCache(cache_dir)

    val_losses = []
    for batch in val_loader:
        batch = batch.to(device)
        losses = trainer.eval_step(batch, props, loss_cfg, args.eval_seq_len, args.max_pairs, cache=cache)
        val_losses.append(losses)

    def avg(key, items):
        return sum(x[key] for x in items) / max(1, len(items))

    log = {
        "val_total": avg("total", val_losses),
        "val_dir": avg("loss_dir", val_losses),
        "val_len": avg("loss_len", val_losses),
        "val_conf": avg("loss_conf", val_losses),
    }
    print(json.dumps(log))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
...@@ -311,6 +311,10 @@ def main() -> int:
    view_dataset_parser.add_argument("--case", required=True, help="Case name under cases/")
    view_dataset_parser.add_argument("--figure", action="store_true", help="Render and save the dataset figure")

    train_prop_parser = subparsers.add_parser("train_property", help="Run property-alignment training")
    train_prop_parser.add_argument("--case", required=True, help="Case name under cases/")
    train_prop_parser.add_argument("extra_args", nargs=argparse.REMAINDER, help="Extra args passed to train_property.py")

    args = parser.parse_args()
    if args.command == "build_cdfg":
...@@ -325,6 +329,11 @@ def main() -> int:
        return _view_cdfg(args.case, args.figure)
    if args.command == "view_dataset":
        return _view_dataset(args.case, args.figure)
    if args.command == "train_property":
        cmd = ["python3", str(REPO_ROOT / "train_property.py"), "--case", args.case]
        if args.extra_args:
            cmd.extend(args.extra_args)
        return subprocess.run(cmd).returncode

    parser.print_help()
    return 1
...
#!/usr/bin/env python3
from __future__ import annotations

from typing import List

import torch
from torch import nn

from property_expr import PropertyExpr


class PropertyEncoder(nn.Module):
    def __init__(self, node_dim: int, hidden_dim: int | None = None):
        super().__init__()
        hidden_dim = hidden_dim or node_dim
        self.not_mlp = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )
        self.and_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )
        self.or_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )
        self.xor_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )
        self.implies_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )
        self.equiv_mlp = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim),
        )

    def encode(self, expr: PropertyExpr, node_emb: torch.Tensor) -> torch.Tensor:
        if expr.op == "VAR":
            return node_emb[expr.var_idx]
        if expr.op == "NOT":
            child = self.encode(expr.args[0], node_emb)
            return self.not_mlp(child)
        if expr.op == "AND":
            left = self.encode(expr.args[0], node_emb)
            right = self.encode(expr.args[1], node_emb)
            return self.and_mlp(torch.cat([left, right], dim=-1))
        if expr.op == "OR":
            left = self.encode(expr.args[0], node_emb)
            right = self.encode(expr.args[1], node_emb)
            return self.or_mlp(torch.cat([left, right], dim=-1))
        if expr.op == "XOR":
            left = self.encode(expr.args[0], node_emb)
            right = self.encode(expr.args[1], node_emb)
            return self.xor_mlp(torch.cat([left, right], dim=-1))
        if expr.op == "IMPLIES":
            left = self.encode(expr.args[0], node_emb)
            right = self.encode(expr.args[1], node_emb)
            return self.implies_mlp(torch.cat([left, right], dim=-1))
        if expr.op == "EQUIV":
            left = self.encode(expr.args[0], node_emb)
            right = self.encode(expr.args[1], node_emb)
            return self.equiv_mlp(torch.cat([left, right], dim=-1))
        raise ValueError(f"Unsupported op: {expr.op}")

    def batch_encode(self, exprs: List[PropertyExpr], node_emb: torch.Tensor) -> torch.Tensor:
        embeds = [self.encode(expr, node_emb) for expr in exprs]
        return torch.stack(embeds, dim=0)
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple
import random


@dataclass(frozen=True)
class PropertyExpr:
    op: str
    args: Tuple["PropertyExpr", ...] = ()
    var_idx: int | None = None

    def to_string(self) -> str:
        if self.op == "VAR":
            return f"v{self.var_idx}"
        if self.op == "NOT":
            return f"~({self.args[0].to_string()})"
        if self.op == "AND":
            return f"({self.args[0].to_string()} & {self.args[1].to_string()})"
        if self.op == "OR":
            return f"({self.args[0].to_string()} | {self.args[1].to_string()})"
        if self.op == "XOR":
            return f"({self.args[0].to_string()} ^ {self.args[1].to_string()})"
        if self.op == "IMPLIES":
            return f"({self.args[0].to_string()} -> {self.args[1].to_string()})"
        if self.op == "EQUIV":
            return f"({self.args[0].to_string()} <-> {self.args[1].to_string()})"
        return self.op


def generate_properties(
    var_indices: List[int],
    max_props: int = 64,
    include_ops: Tuple[str, ...] = ("VAR", "NOT", "AND", "OR", "XOR", "IMPLIES", "EQUIV"),
    seed: int = 1234,
) -> List[PropertyExpr]:
    rng = random.Random(seed)
    props: List[PropertyExpr] = []
    if "VAR" in include_ops:
        for v in var_indices:
            props.append(PropertyExpr(op="VAR", var_idx=v))
    if "NOT" in include_ops:
        for v in var_indices:
            props.append(PropertyExpr(op="NOT", args=(PropertyExpr("VAR", var_idx=v),)))
    if any(op in include_ops for op in ("AND", "OR", "XOR", "IMPLIES", "EQUIV")):
        pairs = [(a, b) for a in var_indices for b in var_indices if a != b]
        rng.shuffle(pairs)
        for a, b in pairs:
            if "AND" in include_ops:
                props.append(PropertyExpr(op="AND", args=(PropertyExpr("VAR", var_idx=a), PropertyExpr("VAR", var_idx=b))))
            if "OR" in include_ops:
                props.append(PropertyExpr(op="OR", args=(PropertyExpr("VAR", var_idx=a), PropertyExpr("VAR", var_idx=b))))
            if "XOR" in include_ops:
                props.append(PropertyExpr(op="XOR", args=(PropertyExpr("VAR", var_idx=a), PropertyExpr("VAR", var_idx=b))))
            if "IMPLIES" in include_ops:
                props.append(PropertyExpr(op="IMPLIES", args=(PropertyExpr("VAR", var_idx=a), PropertyExpr("VAR", var_idx=b))))
            if "EQUIV" in include_ops:
                props.append(PropertyExpr(op="EQUIV", args=(PropertyExpr("VAR", var_idx=a), PropertyExpr("VAR", var_idx=b))))
            if len(props) >= max_props:
                break
    if len(props) > max_props:
        rng.shuffle(props)
        props = props[:max_props]
    return props
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple
import hashlib
from pathlib import Path

import torch
import torch.nn.functional as F

from property_expr import PropertyExpr
from property_encoder import PropertyEncoder


def compute_var_truth(sim_res: torch.Tensor, seq_len: int) -> torch.Tensor:
    # sim_res: [N, T, B]
    return (sim_res[:, :seq_len, :] != 0).any(dim=2).to(torch.bool)


def eval_property(expr: PropertyExpr, var_truth: torch.Tensor) -> torch.Tensor:
    # var_truth: [N, T]
    if expr.op == "VAR":
        return var_truth[expr.var_idx]
    if expr.op == "NOT":
        return ~eval_property(expr.args[0], var_truth)
    if expr.op == "AND":
        return eval_property(expr.args[0], var_truth) & eval_property(expr.args[1], var_truth)
    if expr.op == "OR":
        return eval_property(expr.args[0], var_truth) | eval_property(expr.args[1], var_truth)
    if expr.op == "XOR":
        return eval_property(expr.args[0], var_truth) ^ eval_property(expr.args[1], var_truth)
    if expr.op == "IMPLIES":
        left = eval_property(expr.args[0], var_truth)
        right = eval_property(expr.args[1], var_truth)
        return (~left) | right
    if expr.op == "EQUIV":
        left = eval_property(expr.args[0], var_truth)
        right = eval_property(expr.args[1], var_truth)
        return left == right
    raise ValueError(f"Unsupported op: {expr.op}")


def build_relation_pairs(
    prop_truth: torch.Tensor,
    max_pairs: int = 512,
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    # prop_truth: [P, T] bool
    num_props = prop_truth.size(0)
    implication_pairs = []
    conflict_pairs = []
    for i in range(num_props):
        for j in range(num_props):
            if i == j:
                continue
            pi = prop_truth[i]
            pj = prop_truth[j]
            if torch.all(~pi | pj):
                implication_pairs.append((i, j))
            if torch.all(~(pi & pj)):
                conflict_pairs.append((i, j))
    if len(implication_pairs) > max_pairs:
        implication_pairs = implication_pairs[:max_pairs]
    if len(conflict_pairs) > max_pairs:
        conflict_pairs = conflict_pairs[:max_pairs]
    return implication_pairs, conflict_pairs


@dataclass
class PropertyLossConfig:
    weight_dir: float = 1.0
    weight_len: float = 1.0
    weight_conf: float = 1.0
    margin: float = 0.1


def compute_property_loss(
    prop_emb: torch.Tensor,
    implication_pairs: List[Tuple[int, int]],
    conflict_pairs: List[Tuple[int, int]],
    cfg: PropertyLossConfig,
) -> Dict[str, torch.Tensor]:
    loss_dir = torch.tensor(0.0, device=prop_emb.device)
    loss_len = torch.tensor(0.0, device=prop_emb.device)
    loss_conf = torch.tensor(0.0, device=prop_emb.device)
    if implication_pairs:
        idx_i = torch.tensor([i for i, _ in implication_pairs], device=prop_emb.device)
        idx_j = torch.tensor([j for _, j in implication_pairs], device=prop_emb.device)
        pi = prop_emb[idx_i]
        pj = prop_emb[idx_j]
        cos = F.cosine_similarity(pi, pj, dim=1)
        loss_dir = (1.0 - cos).mean()
        len_i = pi.norm(dim=1)
        len_j = pj.norm(dim=1)
        loss_len = F.relu(len_i - len_j + cfg.margin).mean()
    if conflict_pairs:
        idx_i = torch.tensor([i for i, _ in conflict_pairs], device=prop_emb.device)
        idx_j = torch.tensor([j for _, j in conflict_pairs], device=prop_emb.device)
        pi = prop_emb[idx_i]
        pj = prop_emb[idx_j]
        dot = (pi * pj).sum(dim=1)
        loss_conf = (dot ** 2).mean()
    total = cfg.weight_dir * loss_dir + cfg.weight_len * loss_len + cfg.weight_conf * loss_conf
    return {
        "total": total,
        "loss_dir": loss_dir,
        "loss_len": loss_len,
        "loss_conf": loss_conf,
    }


class PropertyTrainer:
    def __init__(
        self,
        node_encoder,
        prop_encoder: PropertyEncoder,
        device: torch.device,
        lr: float = 1e-4,
        freeze_node_encoder: bool = True,
    ):
        self.device = device
        self.node_encoder = node_encoder.to(device)
        self.prop_encoder = prop_encoder.to(device)
        if freeze_node_encoder:
            for p in self.node_encoder.parameters():
                p.requires_grad = False
        params = list(self.prop_encoder.parameters())
        if not freeze_node_encoder:
            params += list(self.node_encoder.parameters())
        self.optimizer = torch.optim.Adam(params, lr=lr)

    def train_step(
        self,
        batch,
        props: List[PropertyExpr],
        loss_cfg: PropertyLossConfig,
        seq_len: int,
        max_pairs: int,
        cache: "PropertyCache | None" = None,
    ) -> Dict[str, float]:
        self.node_encoder.train()
        self.prop_encoder.train()
        node_emb = self.node_encoder(batch, seq_len)
        prop_emb = self.prop_encoder.batch_encode(props, node_emb)
        if cache is None:
            var_truth = compute_var_truth(batch.sim_res, seq_len)
            prop_truth = torch.stack([eval_property(p, var_truth) for p in props], dim=0)
            imp_pairs, conf_pairs = build_relation_pairs(prop_truth, max_pairs=max_pairs)
        else:
            prop_truth, imp_pairs, conf_pairs = cache.get(batch, props, seq_len, max_pairs)
        losses = compute_property_loss(prop_emb, imp_pairs, conf_pairs, loss_cfg)
        self.optimizer.zero_grad()
        losses["total"].backward()
        self.optimizer.step()
        return {k: v.item() for k, v in losses.items()}

    def eval_step(
        self,
        batch,
        props: List[PropertyExpr],
        loss_cfg: PropertyLossConfig,
        seq_len: int,
        max_pairs: int,
        cache: "PropertyCache | None" = None,
    ) -> Dict[str, float]:
        self.node_encoder.eval()
        self.prop_encoder.eval()
        with torch.no_grad():
            node_emb = self.node_encoder(batch, seq_len)
            prop_emb = self.prop_encoder.batch_encode(props, node_emb)
            if cache is None:
                var_truth = compute_var_truth(batch.sim_res, seq_len)
                prop_truth = torch.stack([eval_property(p, var_truth) for p in props], dim=0)
                imp_pairs, conf_pairs = build_relation_pairs(prop_truth, max_pairs=max_pairs)
            else:
                prop_truth, imp_pairs, conf_pairs = cache.get(batch, props, seq_len, max_pairs)
            losses = compute_property_loss(prop_emb, imp_pairs, conf_pairs, loss_cfg)
        return {k: v.item() for k, v in losses.items()}


class PropertyCache:
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _batch_id(batch) -> str:
        name = getattr(batch, "name", None)
        if isinstance(name, (list, tuple)):
            if len(name) == 1:
                name = name[0]
            else:
                name = "_".join(str(n) for n in name)
        if not name:
            name = f"graph_{id(batch)}"
        safe = "".join(ch if ch.isalnum() or ch in ("_", "-") else "_" for ch in str(name))
        return safe

    @staticmethod
    def _props_sig(props: List[PropertyExpr]) -> str:
        raw = "|".join(p.to_string() for p in props).encode("utf-8")
        return hashlib.md5(raw).hexdigest()

    def get(
        self,
        batch,
        props: List[PropertyExpr],
        seq_len: int,
        max_pairs: int,
    ) -> Tuple[torch.Tensor, List[Tuple[int, int]], List[Tuple[int, int]]]:
        batch_id = self._batch_id(batch)
        sig = self._props_sig(props)
        cache_path = self.cache_dir / f"{batch_id}_len{seq_len}_pairs{max_pairs}_{sig}.pt"
        if cache_path.exists():
            data = torch.load(cache_path, map_location="cpu")
            return data["prop_truth"], data["imp_pairs"], data["conf_pairs"]
        var_truth = compute_var_truth(batch.sim_res, seq_len)
        prop_truth = torch.stack([eval_property(p, var_truth) for p in props], dim=0)
        imp_pairs, conf_pairs = build_relation_pairs(prop_truth, max_pairs=max_pairs)
        torch.save({"prop_truth": prop_truth.cpu(), "imp_pairs": imp_pairs, "conf_pairs": conf_pairs}, cache_path)
        return prop_truth, imp_pairs, conf_pairs
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

import torch
from torch_geometric.loader import DataLoader

from property_expr import generate_properties
from property_encoder import PropertyEncoder
from property_trainer import PropertyTrainer, PropertyLossConfig, PropertyCache

REPO_ROOT = Path(__file__).resolve().parent
MODEL_EXAMPLE_SRC = REPO_ROOT / "model_example" / "src"


def _select_device(device_arg: str) -> torch.device:
    if device_arg.startswith("cuda") and torch.cuda.is_available():
        return torch.device(device_arg)
    return torch.device("cpu")


def _load_meta(data_dir: Path) -> dict:
    meta_path = data_dir / "meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"meta.json not found in {data_dir}")
    return json.loads(meta_path.read_text())


def _get_var_indices(meta: dict, signal_ops: list[str]) -> list[int]:
    node_names = meta.get("node_names", [])
    var_indices = []
    for idx, name in enumerate(node_names):
        op = name.split(",")[0] if name else ""
        if op in signal_ops:
            var_indices.append(idx)
    return var_indices


def main() -> int:
    parser = argparse.ArgumentParser(description="Property-alignment test")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--checkpoint", required=True, help="Property encoder checkpoint path")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--num_rounds", default=20, type=int)
    parser.add_argument("--model", default="default_shared", type=str)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--max_props", default=64, type=int)
    parser.add_argument("--max_pairs", default=512, type=int)
    parser.add_argument("--signal_ops", default="Input,Output,Wire,Reg", type=str)
    parser.add_argument("--include_ops", default="VAR,NOT,AND,OR,XOR,IMPLIES,EQUIV", type=str)
    parser.add_argument("--cache_dir", default="", type=str)
    parser.add_argument("--no_cache", action="store_true")
    args = parser.parse_args()

    data_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graph_npz = data_dir / args.graph_npz_name
    label_npz = data_dir / args.label_npz_name

    os.chdir(REPO_ROOT)
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402

    model_factory = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_factory:
        raise ValueError(f"Model not supported: {args.model}")

    device = _select_device(args.device)
    meta = _load_meta(data_dir)
    signal_ops = [s.strip() for s in args.signal_ops.split(",") if s.strip()]
    var_indices = _get_var_indices(meta, signal_ops)
    include_ops = tuple(s.strip() for s in args.include_ops.split(",") if s.strip())
    props = generate_properties(var_indices, max_props=args.max_props, include_ops=include_ops)

    dataset = NpzParser(str(data_dir), str(graph_npz), str(label_npz))
    _, test_dataset = dataset.get_dataset(split_with_design=True)
    if args.batch_size != 1:
        raise ValueError("Property testing assumes batch_size=1 for correct indexing.")
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

    node_encoder = model_factory[args.model](num_rounds=args.num_rounds)
    prop_encoder = PropertyEncoder(node_dim=node_encoder.dim_hidden)
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    prop_encoder.load_state_dict(checkpoint["state_dict"])
    trainer = PropertyTrainer(node_encoder=node_encoder, prop_encoder=prop_encoder, device=device, freeze_node_encoder=True)
    loss_cfg = PropertyLossConfig()

    cache = None
    if not args.no_cache:
        cache_dir = Path(args.cache_dir) if args.cache_dir else (REPO_ROOT / "results" / args.case / "property" / "test_cache")
        cache = PropertyCache(cache_dir)

    test_losses = []
    for batch in test_loader:
        batch = batch.to(device)
        losses = trainer.eval_step(batch, props, loss_cfg, args.eval_seq_len, args.max_pairs, cache=cache)
        test_losses.append(losses)

    def avg(key, items):
        return sum(x[key] for x in items) / max(1, len(items))

    log = {
        "test_total": avg("total", test_losses),
        "test_dir": avg("loss_dir", test_losses),
        "test_len": avg("loss_len", test_losses),
        "test_conf": avg("loss_conf", test_losses),
    }
    print(json.dumps(log))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path

import torch
from torch_geometric.loader import DataLoader

from property_expr import generate_properties
from property_encoder import PropertyEncoder
from property_trainer import PropertyTrainer, PropertyLossConfig, PropertyCache

REPO_ROOT = Path(__file__).resolve().parent
MODEL_EXAMPLE_SRC = REPO_ROOT / "model_example" / "src"


def _select_device(device_arg: str) -> torch.device:
    if device_arg.startswith("cuda") and torch.cuda.is_available():
        return torch.device(device_arg)
    return torch.device("cpu")


def _load_meta(data_dir: Path) -> dict:
    meta_path = data_dir / "meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"meta.json not found in {data_dir}")
    return json.loads(meta_path.read_text())


def _get_var_indices(meta: dict, signal_ops: list[str]) -> list[int]:
    node_names = meta.get("node_names", [])
    var_indices = []
    for idx, name in enumerate(node_names):
        op = name.split(",")[0] if name else ""
        if op in signal_ops:
            var_indices.append(idx)
    return var_indices


def main() -> int:
    parser = argparse.ArgumentParser(description="Property-alignment training")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--num_rounds", default=20, type=int)
    parser.add_argument("--model", default="default_shared", type=str)
    parser.add_argument("--train_seq_len", default=50, type=int)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--max_props", default=64, type=int)
    parser.add_argument("--max_pairs", default=512, type=int)
    parser.add_argument("--freeze_node_encoder", action="store_true")
    parser.add_argument("--signal_ops", default="Input,Output,Wire,Reg", type=str)
    parser.add_argument("--include_ops", default="VAR,NOT,AND,OR,XOR,IMPLIES,EQUIV", type=str)
    parser.add_argument("--cache_dir", default="", type=str)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--exp_id", default="property", type=str)
    parser.add_argument("--w_dir", default=1.0, type=float)
    parser.add_argument("--w_len", default=1.0, type=float)
    parser.add_argument("--w_conf", default=1.0, type=float)
    parser.add_argument("--margin", default=0.1, type=float)
    args = parser.parse_args()

    data_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graph_npz = data_dir / args.graph_npz_name
    label_npz = data_dir / args.label_npz_name

    os.chdir(REPO_ROOT)
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402

    model_factory = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_factory:
        raise ValueError(f"Model not supported: {args.model}")

    device = _select_device(args.device)
    meta = _load_meta(data_dir)
    signal_ops = [s.strip() for s in args.signal_ops.split(",") if s.strip()]
    var_indices = _get_var_indices(meta, signal_ops)
    include_ops = tuple(s.strip() for s in args.include_ops.split(",") if s.strip())
    props = generate_properties(var_indices, max_props=args.max_props, include_ops=include_ops)

    dataset = NpzParser(str(data_dir), str(graph_npz), str(label_npz))
    train_dataset, val_dataset = dataset.get_dataset(split_with_design=True)
    if args.batch_size != 1:
        raise ValueError("Property training assumes batch_size=1 for correct indexing.")
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

    node_encoder = model_factory[args.model](num_rounds=args.num_rounds)
    prop_encoder = PropertyEncoder(node_dim=node_encoder.dim_hidden)
    trainer = PropertyTrainer(
        node_encoder=node_encoder,
        prop_encoder=prop_encoder,
        device=device,
        lr=args.lr,
        freeze_node_encoder=args.freeze_node_encoder,
    )
    loss_cfg = PropertyLossConfig(weight_dir=args.w_dir, weight_len=args.w_len, weight_conf=args.w_conf, margin=args.margin)

    exp_dir = REPO_ROOT / "results" / args.case / "property" / args.exp_id
    exp_dir.mkdir(parents=True, exist_ok=True)

    cache = None
    if not args.no_cache:
        cache_dir = Path(args.cache_dir) if args.cache_dir else (exp_dir / "cache")
        cache = PropertyCache(cache_dir)

    for epoch in range(args.num_epochs):
        train_losses = []
        for batch in train_loader:
            batch = batch.to(device)
            losses = trainer.train_step(batch, props, loss_cfg, args.train_seq_len, args.max_pairs, cache=cache)
            train_losses.append(losses)

        val_losses = []
        for batch in val_loader:
            batch = batch.to(device)
            losses = trainer.eval_step(batch, props, loss_cfg, args.eval_seq_len, args.max_pairs, cache=cache)
            val_losses.append(losses)

        def avg(key, items):
            return sum(x[key] for x in items) / max(1, len(items))

        log = {
            "epoch": epoch,
            "train_total": avg("total", train_losses),
            "val_total": avg("total", val_losses),
            "train_dir": avg("loss_dir", train_losses),
            "train_len": avg("loss_len", train_losses),
            "train_conf": avg("loss_conf", train_losses),
            "val_dir": avg("loss_dir", val_losses),
            "val_len": avg("loss_len", val_losses),
            "val_conf": avg("loss_conf", val_losses),
        }
        print(json.dumps(log))

        ckpt_path = exp_dir / "property_encoder_last.pth"
        torch.save({"state_dict": trainer.prop_encoder.state_dict()}, ckpt_path)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())