Commit ce346a35 by lvzhengyang

Implement the training logic.

parent d8d9890b
......@@ -38,6 +38,28 @@ Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
### Stage-1 Training (DynamicRTL-style)
Stage-1 training expects operator-type node features and bitvector waveforms. For `eth_fifo`, configure `[dataset]` in `cases/eth_fifo/config.toml`:
```
node_type_mode = "category"
edge_type_mode = "numeric"
sim_res_mode = "bitvector"
sim_res_width = 32
build_labels = true
op_to_index_path = "operator_types/op_to_index.json"
category_map_path = "operator_types/category_map.json"
signal_ops = ["Input", "Output", "Wire", "Reg"]
```
Rebuild the dataset and launch training:
```bash
./main.py build_dataset --case eth_fifo
python3 train_stage1.py --case eth_fifo --model default_shared
```
### View CDFG / Dataset
```bash
......
......@@ -29,3 +29,11 @@ clock_edge = "posedge"
# output_dir = "results/eth_fifo/dataset"
# Optional override for CDFG dot if needed.
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
node_type_mode = "category"
edge_type_mode = "numeric"
sim_res_mode = "bitvector"
sim_res_width = 32
build_labels = true
op_to_index_path = "operator_types/op_to_index.json"
category_map_path = "operator_types/category_map.json"
signal_ops = ["Input", "Output", "Wire", "Reg"]
......@@ -26,6 +26,15 @@ def sanitize_name(name: str) -> str:
@dataclass
class NodeInfo:
    """Metadata for one CDFG node parsed from a dot node name.

    Dot node names encode comma-separated fields ("op,width,base_name");
    see _parse_node_name for the exact decoding rules.
    """

    # full_name: the raw node name exactly as it appears in the .dot file
    # op: operator type token (first comma-separated field)
    # width: bit width (0 when unknown or "null")
    # base_name: remaining fields joined back together with commas
    # attrs: raw dot attributes (e.g. color) collected for this node
    full_name: str
    op: str
    width: int
    base_name: str
    attrs: Dict[str, str]
@dataclass
class GraphData:
node_names: List[str]
node_types: List[int]
......@@ -45,17 +54,44 @@ def _parse_attr_map(attr_text: str) -> Dict[str, str]:
return attrs
def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]], List[Tuple[str, str, str]]]:
node_attrs: Dict[str, Dict[str, str]] = {}
node_order: List[str] = []
def _parse_node_name(name: str) -> Tuple[str, int, str]:
parts = name.split(",")
op = parts[0] if parts else name
width = 0
base_name = name
if len(parts) >= 2:
width_str = parts[1].strip()
if width_str.isdigit():
width = int(width_str)
elif width_str.lower() == "null":
width = 0
else:
digits = "".join(ch for ch in width_str if ch.isdigit())
width = int(digits) if digits else 0
if len(parts) >= 3:
base_name = ",".join(parts[2:])
return op, width, base_name
def parse_dot_graph_with_meta(dot_path: Path) -> Tuple[List[NodeInfo], List[Tuple[str, str, str]]]:
node_infos: List[NodeInfo] = []
node_map: Dict[str, NodeInfo] = {}
edges: List[Tuple[str, str, str]] = []
def add_node(name: str, attrs: Optional[Dict[str, str]] = None) -> None:
if name not in node_attrs:
node_order.append(name)
node_attrs[name] = attrs or {}
if name not in node_map:
op, width, base_name = _parse_node_name(name)
info = NodeInfo(
full_name=name,
op=op,
width=width,
base_name=base_name,
attrs=attrs or {},
)
node_infos.append(info)
node_map[name] = info
elif attrs:
node_attrs[name].update(attrs)
node_map[name].attrs.update(attrs)
with dot_path.open("r") as f:
for raw in f:
......@@ -63,7 +99,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
if not line or line.startswith("digraph") or line in {"{", "}"}:
continue
if "->" in line:
# Edge line
m = re.search(r'\"?(.+?)\"?\\s*->\\s*\"?(.+?)\"?\\s*\\[label=\"(.*?)\"\\]', line)
if not m:
continue
......@@ -75,7 +110,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
edges.append((src, dst, label))
continue
# Node line
if "[" in line and "]" in line:
name_part, attr_part = line.split("[", 1)
name = name_part.strip().strip(";").strip().strip("\"")
......@@ -87,41 +121,102 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
if name:
add_node(name)
return node_infos, edges
def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]], List[Tuple[str, str, str]]]:
    """Backward-compatible wrapper around parse_dot_graph_with_meta.

    Returns (node names in encounter order, name -> attrs map, edges).
    """
    infos, edges = parse_dot_graph_with_meta(dot_path)
    names = [info.full_name for info in infos]
    attrs_by_name = {info.full_name: info.attrs for info in infos}
    return names, attrs_by_name, edges
def build_graph(dot_path: Path) -> GraphData:
node_order, node_attrs, edges = parse_dot_graph(dot_path)
def _edge_label_to_int(label: str) -> int:
digits = "".join(ch for ch in label if ch.isdigit())
return int(digits) if digits else 0
node_types: List[int] = []
for name in node_order:
attrs = node_attrs.get(name, {})
color = attrs.get("color")
node_types.append(COLOR_TO_TYPE.get(color, 0))
def _load_category_map(path: Path, ops: List[str]) -> Dict[str, str]:
if path.exists():
with path.open("r") as f:
category_map = json.load(f)
else:
category_map = {}
updated = False
for op in ops:
if op not in category_map:
category_map[op] = op
updated = True
if updated or not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w") as f:
json.dump(category_map, f, indent=2)
return category_map
def build_graph(
dot_path: Path,
node_type_mode: str = "color",
edge_type_mode: str = "label_map",
category_map_path: Optional[Path] = None,
) -> Tuple[GraphData, List[NodeInfo], Optional[Dict[str, int]], Optional[Dict[str, str]]]:
node_infos, edges = parse_dot_graph_with_meta(dot_path)
node_names = [info.full_name for info in node_infos]
node_idx = {name: idx for idx, name in enumerate(node_names)}
op_to_index: Optional[Dict[str, int]] = None
category_map: Optional[Dict[str, str]] = None
if node_type_mode == "category":
ops = sorted({info.op for info in node_infos})
if category_map_path is None:
category_map_path = Path("operator_types") / "category_map.json"
category_map = _load_category_map(category_map_path, ops)
categories = sorted(set(category_map.values()))
op_to_index = {cat: idx for idx, cat in enumerate(categories)}
node_types = []
for info in node_infos:
cat = category_map.get(info.op, info.op)
if cat not in op_to_index:
op_to_index[cat] = len(op_to_index)
node_types.append(op_to_index[cat])
elif node_type_mode == "op":
ops = sorted({info.op for info in node_infos})
op_to_index = {op: idx for idx, op in enumerate(ops)}
node_types = [op_to_index[info.op] for info in node_infos]
else:
node_types = []
for info in node_infos:
color = info.attrs.get("color")
node_types.append(COLOR_TO_TYPE.get(color, 0))
label_map: Dict[str, int] = {}
edge_index: List[List[int]] = [[], []]
edge_index: List[List[int]] = []
edge_type: List[int] = []
node_idx = {name: idx for idx, name in enumerate(node_order)}
for src, dst, label in edges:
if label not in label_map:
label_map[label] = len(label_map)
edge_index[0].append(node_idx[src])
edge_index[1].append(node_idx[dst])
edge_type.append(label_map[label])
if src not in node_idx or dst not in node_idx:
continue
if edge_type_mode == "numeric":
etype = _edge_label_to_int(label)
label_map.setdefault(label, etype)
else:
if label not in label_map:
label_map[label] = len(label_map)
etype = label_map[label]
edge_index.append([node_idx[src], node_idx[dst]])
edge_type.append(etype)
edge_index_arr = np.array(edge_index, dtype=np.int64)
edge_type_arr = np.array(edge_type, dtype=np.int64)
return GraphData(node_order, node_types, edge_index_arr, edge_type_arr, label_map)
graph = GraphData(node_names, node_types, edge_index_arr, edge_type_arr, label_map)
return graph, node_infos, op_to_index, category_map
def _parse_vcd_var_line(tokens: List[str]) -> Tuple[int, str, str]:
width = int(tokens[2])
var_id = tokens[3]
var_name = tokens[4]
if len(tokens) > 6 and tokens[5].startswith("["):
var_name = f"{var_name}{tokens[5]}"
return width, var_id, var_name
......@@ -135,61 +230,89 @@ def _value_to_int(value: str) -> int:
return int(value, 2)
def _int_to_bitvector(value: int, width: int, out_width: int) -> List[int]:
if out_width <= 0:
raise ValueError("sim_res_width must be positive")
if value < 0:
return [0] * out_width
if width <= 0:
width = out_width
mask = (1 << width) - 1 if width < 63 else (1 << out_width) - 1
value = value & mask
bits = [(value >> (width - 1 - i)) & 1 for i in range(width)]
if width < out_width:
bits = [0] * (out_width - width) + bits
elif width > out_width:
bits = bits[-out_width:]
return bits
def parse_vcd_trace(
vcd_path: Path,
node_names: List[str],
signal_map: Dict[str, List[int]],
node_widths: List[int],
clock: str,
reset: Optional[str],
reset_active: str,
scope: Optional[str] = None,
clock_edge: str = "posedge",
max_cycles: Optional[int] = None,
) -> Tuple[np.ndarray, Dict[str, int], List[str]]:
node_set = set(node_names)
sim_res_mode: str = "scalar",
sim_res_width: int = 32,
) -> Tuple[np.ndarray, Dict[int, int], set[int]]:
num_nodes = len(node_widths)
clock_san = sanitize_name(clock)
reset_san = sanitize_name(reset) if reset else None
id_to_signal: Dict[str, str] = {}
signal_widths: Dict[str, int] = {}
current_values: Dict[str, int] = {}
id_to_nodes: Dict[str, List[int]] = {}
signal_widths: Dict[int, int] = {}
mapped_nodes: set[int] = set()
current_values = [-1] * num_nodes
scope_stack: List[str] = []
clock_id: Optional[str] = None
reset_id: Optional[str] = None
clock_val = -1
reset_val = -1
def use_signal(rel_name: str, width: int, var_id: str) -> None:
nonlocal clock_id, reset_id
name_san = sanitize_name(rel_name)
if name_san in node_set or name_san in {clock_san, reset_san}:
id_to_signal[var_id] = name_san
signal_widths.setdefault(name_san, width)
current_values.setdefault(name_san, -1)
if name_san == clock_san:
clock_id = var_id
if reset_san and name_san == reset_san:
reset_id = var_id
in_defs = True
in_dumpvars = False
if name_san == clock_san:
clock_id = var_id
if reset_san and name_san == reset_san:
reset_id = var_id
if name_san in signal_map:
id_to_nodes[var_id] = signal_map[name_san]
for idx in signal_map[name_san]:
mapped_nodes.add(idx)
signal_widths.setdefault(idx, width)
last_clock: Optional[int] = None
pending_sample = False
samples_collected = 0
values_per_node = [[] for _ in node_names]
values_per_node: List[list] = [[] for _ in range(num_nodes)]
def record_sample() -> None:
nonlocal samples_collected
if max_cycles is not None and samples_collected >= max_cycles:
return
if reset_id is not None:
reset_val = current_values.get(reset_san, -1)
if reset_val == -1:
return
if reset_active == "high" and reset_val == 1:
return
if reset_active == "low" and reset_val == 0:
return
for idx, name in enumerate(node_names):
values_per_node[idx].append(current_values.get(name, -1))
if sim_res_mode == "scalar":
for idx in range(num_nodes):
values_per_node[idx].append(current_values[idx])
elif sim_res_mode == "bitvector":
for idx in range(num_nodes):
width = node_widths[idx] if node_widths[idx] > 0 else sim_res_width
values_per_node[idx].append(_int_to_bitvector(current_values[idx], width, sim_res_width))
else:
raise ValueError(f"Unsupported sim_res_mode: {sim_res_mode}")
samples_collected += 1
with vcd_path.open("r") as f:
......@@ -221,13 +344,8 @@ def parse_vcd_trace(
use_signal(rel_name, width, var_id)
continue
if line.startswith("$enddefinitions"):
in_defs = False
continue
if line.startswith("$dumpvars"):
in_dumpvars = True
continue
if in_dumpvars and line.startswith("$end"):
in_dumpvars = False
continue
if line.startswith("#"):
if pending_sample:
......@@ -237,7 +355,6 @@ def parse_vcd_trace(
break
continue
# Value change
if line[0] in "01xXzZ":
value = line[0]
var_id = line[1:].strip()
......@@ -250,12 +367,17 @@ def parse_vcd_trace(
else:
continue
if var_id not in id_to_signal:
continue
signal_name = id_to_signal[var_id]
value_int = _value_to_int(value)
current_values[signal_name] = value_int
if var_id == clock_id:
clock_val = value_int
if var_id == reset_id:
reset_val = value_int
if var_id not in id_to_nodes:
continue
for idx in id_to_nodes[var_id]:
current_values[idx] = value_int
if var_id == clock_id and value_int != -1:
if last_clock is None:
......@@ -271,7 +393,6 @@ def parse_vcd_trace(
record_sample()
trace = np.array(values_per_node, dtype=np.int64)
mapped_nodes = [name for name in node_names if name in signal_widths]
return trace, signal_widths, mapped_nodes
......@@ -284,6 +405,15 @@ def build_dataset(
dataset_cfg = config.get("dataset", {})
wave_cfg = config.get("waveform", {})
node_type_mode = dataset_cfg.get("node_type_mode", "color")
edge_type_mode = dataset_cfg.get("edge_type_mode", "label_map")
sim_res_mode = dataset_cfg.get("sim_res_mode", "scalar")
sim_res_width = int(dataset_cfg.get("sim_res_width", 32))
signal_ops = dataset_cfg.get("signal_ops")
op_to_index_path = dataset_cfg.get("op_to_index_path")
category_map_path_cfg = dataset_cfg.get("category_map_path")
build_labels = bool(dataset_cfg.get("build_labels", False))
cdfg_output_dir = cdfg_cfg.get("output_dir")
if cdfg_output_dir:
cdfg_dir = (repo_root / cdfg_output_dir).resolve() if not Path(cdfg_output_dir).is_absolute() else Path(cdfg_output_dir)
......@@ -302,7 +432,17 @@ def build_dataset(
raise FileNotFoundError(f"No .dot files found in {cdfg_dir}")
dot_path = dot_files[0]
graph = build_graph(dot_path)
category_map_path = None
if category_map_path_cfg:
category_map_path = Path(category_map_path_cfg)
if not category_map_path.is_absolute():
category_map_path = (repo_root / category_map_path).resolve()
graph, node_infos, op_to_index, category_map = build_graph(
dot_path,
node_type_mode=node_type_mode,
edge_type_mode=edge_type_mode,
category_map_path=category_map_path,
)
vcd_glob = wave_cfg.get("vcd_glob")
vcd_files = wave_cfg.get("vcd_files")
......@@ -333,33 +473,59 @@ def build_dataset(
raise KeyError("Missing waveform.clock in config.toml")
if reset_active not in {"high", "low"}:
raise ValueError("waveform.reset_active must be 'high' or 'low'")
if sim_res_mode == "bitvector" and sim_res_width <= 0:
raise ValueError("dataset.sim_res_width must be positive for bitvector mode")
if signal_ops is None:
signal_ops_set = {"Input", "Output", "Wire", "Reg"}
else:
signal_ops_set = {str(op) for op in signal_ops}
signal_map: Dict[str, List[int]] = {}
node_widths: List[int] = []
for idx, info in enumerate(node_infos):
width = info.width if info.width > 0 else 1
node_widths.append(width)
if info.op not in signal_ops_set:
continue
base = sanitize_name(info.base_name)
if not base:
continue
signal_map.setdefault(base, []).append(idx)
sim_res: List[np.ndarray] = []
signal_widths: Dict[str, int] = {}
mapped_nodes: set[str] = set()
signal_widths: Dict[int, int] = {}
mapped_nodes: set[int] = set()
trace_names: List[str] = []
for vcd_path in vcd_paths:
trace, widths, mapped = parse_vcd_trace(
vcd_path=vcd_path,
node_names=graph.node_names,
signal_map=signal_map,
node_widths=node_widths,
clock=clock,
reset=reset,
reset_active=reset_active,
scope=scope,
clock_edge=clock_edge,
max_cycles=max_cycles,
sim_res_mode=sim_res_mode,
sim_res_width=sim_res_width,
)
sim_res.append(trace)
signal_widths.update(widths)
mapped_nodes.update(mapped)
trace_names.append(str(vcd_path))
has_sim_res = np.array([1 if name in mapped_nodes else 0 for name in graph.node_names], dtype=np.int64)
node_widths = np.array([signal_widths.get(name, 1) for name in graph.node_names], dtype=np.int64)
for idx, width in signal_widths.items():
if node_widths[idx] <= 1 and width > 1:
node_widths[idx] = width
has_sim_res = np.array([1 if idx in mapped_nodes else 0 for idx in range(len(graph.node_names))], dtype=np.int64)
node_ids = np.arange(len(graph.node_names), dtype=np.int64)
node_types = np.array(graph.node_types, dtype=np.int64)
x = np.stack([node_ids, node_types, node_widths], axis=1)
node_widths_arr = np.array(node_widths, dtype=np.int64)
x = np.stack([node_ids, node_types, node_widths_arr], axis=1)
designs = {
case_name: {
......@@ -383,11 +549,25 @@ def build_dataset(
output_npz = output_dir / "graphs.npz"
np.savez_compressed(output_npz, designs=designs)
if build_labels:
labels = {name: {"y": 0} for name in designs}
labels_npz = output_dir / "labels.npz"
np.savez_compressed(labels_npz, labels=labels)
if op_to_index and op_to_index_path:
op_path = Path(op_to_index_path)
if not op_path.is_absolute():
op_path = (repo_root / op_path).resolve()
op_path.parent.mkdir(parents=True, exist_ok=True)
with op_path.open("w") as f:
json.dump(op_to_index, f, indent=2)
meta = {
"case": case_name,
"cdfg_dot": str(dot_path),
"vcd_files": trace_names,
"node_names": graph.node_names,
"node_base_names": [info.base_name for info in node_infos],
"edge_label_map": graph.edge_label_map,
"clock": clock,
"reset": reset,
......@@ -395,7 +575,20 @@ def build_dataset(
"scope": scope,
"clock_edge": clock_edge,
"max_cycles": max_cycles,
"node_type_mode": node_type_mode,
"edge_type_mode": edge_type_mode,
"sim_res_mode": sim_res_mode,
"sim_res_width": sim_res_width,
"signal_ops": sorted(signal_ops_set),
}
if op_to_index:
meta["op_to_index"] = op_to_index
if op_to_index_path:
meta["op_to_index_path"] = op_to_index_path
if category_map:
meta["category_map"] = category_map
if category_map_path_cfg:
meta["category_map_path"] = category_map_path_cfg
with (output_dir / "meta.json").open("w") as f:
json.dump(meta, f, indent=2)
......
{
"Add": "Add",
"BitAnd": "BitAnd",
"BitXor": "BitXor",
"Concat": "Concat",
"Cond": "Cond",
"Cond_If": "Cond_If",
"Const": "Const",
"Eq": "Eq",
"Input": "Input",
"Not": "Not",
"Output": "Output",
"PartSelect": "PartSelect",
"Reg": "Reg",
"Sub": "Sub",
"URand": "URand",
"URor": "URor",
"Wire": "Wire"
}
\ No newline at end of file
{
"Add": 0,
"BitAnd": 1,
"BitXor": 2,
"Concat": 3,
"Cond": 4,
"Cond_If": 5,
"Const": 6,
"Eq": 7,
"Input": 8,
"Not": 9,
"Output": 10,
"PartSelect": 11,
"Reg": 12,
"Sub": 13,
"URand": 14,
"URor": 15,
"Wire": 16
}
\ No newline at end of file
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
REPO_ROOT = Path(__file__).resolve().parent
MODEL_EXAMPLE_SRC = REPO_ROOT / "model_example" / "src"
def _ensure_op_to_index(data_dir: Path) -> None:
    """Make sure op_to_index.json exists so the model can map operator names.

    Prefers the mapping recorded in the dataset's meta.json (including its
    configured output path); otherwise derives one from the node names.
    Warns when no mapping can be produced or when keys the default model
    expects are missing.
    """
    meta_path = data_dir / "meta.json"
    op_to_index = None
    op_path = REPO_ROOT / "operator_types" / "op_to_index.json"
    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        op_to_index = meta.get("op_to_index")
        configured = meta.get("op_to_index_path")
        if configured:
            candidate = Path(configured)
            if not candidate.is_absolute():
                candidate = (REPO_ROOT / candidate).resolve()
            op_path = candidate
        if op_to_index is None:
            # Fall back to deriving the mapping from the op prefix of each
            # node name ("op,width,base_name").
            names = meta.get("node_names", [])
            derived_ops = sorted({n.split(",")[0] for n in names if n})
            if derived_ops:
                op_to_index = {op: i for i, op in enumerate(derived_ops)}
    if op_to_index is None:
        print("[train_stage1] Warning: op_to_index not found; training may fail.")
        return
    op_path.parent.mkdir(parents=True, exist_ok=True)
    op_path.write_text(json.dumps(op_to_index, indent=2))
    print(f"[train_stage1] op_to_index saved to {op_path}")
    required_keys = {"Input", "Const", "Wire", "Reg", "Cond", "Output"}
    missing = required_keys - set(op_to_index.keys())
    if missing:
        print(f"[train_stage1] Warning: op_to_index missing keys {sorted(missing)}. "
              "If you merged categories, update the model or mapping accordingly.")
def _ensure_labels(data_dir: Path, graph_npz: Path, label_npz: Path) -> None:
if label_npz.exists():
return
if not graph_npz.exists():
raise FileNotFoundError(f"graphs.npz not found: {graph_npz}")
with np.load(graph_npz, allow_pickle=True) as data:
designs = data["designs"].item()
labels = {name: {"y": 0} for name in designs}
label_npz.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(label_npz, labels=labels)
print(f"[train_stage1] labels.npz created at {label_npz}")
def _select_device(device_arg: str) -> torch.device:
if device_arg.startswith("cuda") and torch.cuda.is_available():
return torch.device(device_arg)
return torch.device("cpu")
def main() -> int:
    """Stage-1 training driver.

    Parses CLI arguments, prepares dataset side files (op_to_index.json and
    a placeholder labels.npz), then delegates training to the DynamicRTL-style
    Trainer shipped under model_example/src. Returns 0 on success.
    """
    parser = argparse.ArgumentParser(description="Stage-1 training (DynamicRTL style)")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--distributed", action="store_true", help="If set, train in distributed mode")
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_step", default=50, type=int)
    parser.add_argument("--num_epochs", default=60, type=int)
    parser.add_argument("--num_rounds", default=20, type=int, help="Number of rounds to GNN propagate")
    parser.add_argument("--train_seq_len", default=50, type=int)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--device", default="cuda", type=str, help="cpu / cuda / cuda:0")
    parser.add_argument("--gpus", default="0", type=str, help="GPU IDs to use, example: 0,1,2,3")
    parser.add_argument("--model", default="default_shared", type=str, help="default_shared or default")
    parser.add_argument("--exp_id", default="stage1", type=str, help="Experiment ID")
    parser.add_argument("--supervision", default="default", type=str, help="only_branch/only_tgl/default")
    parser.add_argument("--split_with_design", action="store_true", help="Split dataset by design name")
    args = parser.parse_args()

    # Resolve dataset locations; an explicit --data_dir overrides the
    # per-case default under results/<case>/dataset.
    data_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graph_npz = data_dir / args.graph_npz_name
    label_npz = data_dir / args.label_npz_name

    # Run from the repo root so relative paths written by the helpers land in
    # the expected place, then make sure the side files exist before training.
    os.chdir(REPO_ROOT)
    _ensure_op_to_index(data_dir)
    _ensure_labels(data_dir, graph_npz, label_npz)

    # The model/trainer code lives in model_example/src and is imported lazily
    # (after extending sys.path) so this script runs without installation.
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402
    from trainer import Trainer  # noqa: E402

    model_factory = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_factory:
        raise ValueError(f"Model not supported: {args.model}")
    if args.model == "default":
        print("[train_stage1] Warning: Model_default assumes fixed operator indices. Prefer --model default_shared.")

    device = _select_device(args.device)
    print(f"[train_stage1] Using device: {device}")

    dataset = NpzParser(str(data_dir), str(graph_npz), str(label_npz))
    train_dataset, val_dataset = dataset.get_dataset(split_with_design=args.split_with_design)

    model = model_factory[args.model](num_rounds=args.num_rounds)
    # Timestamped training id keeps checkpoints/logs from separate runs apart.
    time_str = time.strftime("%Y-%m-%d-%H-%M")
    trainer = Trainer(
        args,
        model,
        distributed=args.distributed,
        batch_size=args.batch_size,
        device=device,
        gpus=args.gpus,
        training_id=time_str,
    )
    trainer.set_training_args(lr=args.lr, lr_step=args.lr_step)

    print("[train_stage1] Stage 1 Training ...")
    trainer.train(
        args.num_epochs,
        train_dataset,
        val_dataset,
        train_seq_len=args.train_seq_len,
        eval_seq_len=args.eval_seq_len,
        supervision=args.supervision,
    )
    print("[train_stage1] Finish Training")
    return 0
if __name__ == "__main__":
raise SystemExit(main())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment