Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
P
PropertyEmb
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lvzhengyang
PropertyEmb
Commits
ce346a35
Commit
ce346a35
authored
Feb 28, 2026
by
lvzhengyang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
impl. the train logic
parent
d8d9890b
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
481 additions
and
66 deletions
+481
-66
README
+22
-0
cases/eth_fifo/config.toml
+8
-0
dataset_builder.py
+259
-66
operator_types/category_map.json
+20
-0
operator_types/op_to_index.json
+20
-0
train_stage1.py
+152
-0
No files found.
README
View file @
ce346a35
...
@@ -38,6 +38,28 @@ Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and
...
@@ -38,6 +38,28 @@ Configure waveform parsing in `cases/<case>/config.toml` under `[waveform]` and
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
By default the dataset is saved to `results/<case>/dataset/graphs.npz` with metadata in `results/<case>/dataset/meta.json`.
### Stage-1 Training (DynamicRTL-style)
Stage-1 training expects operator-type node features and bitvector waveforms. For `eth_fifo`, configure `[dataset]` in `cases/eth_fifo/config.toml`:
```
node_type_mode = "category"
edge_type_mode = "numeric"
sim_res_mode = "bitvector"
sim_res_width = 32
build_labels = true
op_to_index_path = "operator_types/op_to_index.json"
category_map_path = "operator_types/category_map.json"
signal_ops = ["Input", "Output", "Wire", "Reg"]
```
Rebuild the dataset and launch training:
```bash
./main.py build_dataset --case eth_fifo
python3 train_stage1.py --case eth_fifo --model default_shared
```
### View CDFG / Dataset
### View CDFG / Dataset
```bash
```bash
...
...
cases/eth_fifo/config.toml
View file @
ce346a35
...
@@ -29,3 +29,11 @@ clock_edge = "posedge"
...
@@ -29,3 +29,11 @@ clock_edge = "posedge"
# output_dir = "results/eth_fifo/dataset"
# output_dir = "results/eth_fifo/dataset"
# Optional override for CDFG dot if needed.
# Optional override for CDFG dot if needed.
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
# cdfg_dot = "results/eth_fifo/cdfg/eth_fifo_ast_clean_cdfg.dot"
node_type_mode
=
"category"
edge_type_mode
=
"numeric"
sim_res_mode
=
"bitvector"
sim_res_width
=
32
build_labels
=
true
op_to_index_path
=
"operator_types/op_to_index.json"
category_map_path
=
"operator_types/category_map.json"
signal_ops
=
[
"Input"
,
"Output"
,
"Wire"
,
"Reg"
]
dataset_builder.py
View file @
ce346a35
...
@@ -26,6 +26,15 @@ def sanitize_name(name: str) -> str:
...
@@ -26,6 +26,15 @@ def sanitize_name(name: str) -> str:
@dataclass
@dataclass
class
NodeInfo
:
full_name
:
str
op
:
str
width
:
int
base_name
:
str
attrs
:
Dict
[
str
,
str
]
@dataclass
class
GraphData
:
class
GraphData
:
node_names
:
List
[
str
]
node_names
:
List
[
str
]
node_types
:
List
[
int
]
node_types
:
List
[
int
]
...
@@ -45,17 +54,44 @@ def _parse_attr_map(attr_text: str) -> Dict[str, str]:
...
@@ -45,17 +54,44 @@ def _parse_attr_map(attr_text: str) -> Dict[str, str]:
return
attrs
return
attrs
def
parse_dot_graph
(
dot_path
:
Path
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
Dict
[
str
,
str
]],
List
[
Tuple
[
str
,
str
,
str
]]]:
def
_parse_node_name
(
name
:
str
)
->
Tuple
[
str
,
int
,
str
]:
node_attrs
:
Dict
[
str
,
Dict
[
str
,
str
]]
=
{}
parts
=
name
.
split
(
","
)
node_order
:
List
[
str
]
=
[]
op
=
parts
[
0
]
if
parts
else
name
width
=
0
base_name
=
name
if
len
(
parts
)
>=
2
:
width_str
=
parts
[
1
]
.
strip
()
if
width_str
.
isdigit
():
width
=
int
(
width_str
)
elif
width_str
.
lower
()
==
"null"
:
width
=
0
else
:
digits
=
""
.
join
(
ch
for
ch
in
width_str
if
ch
.
isdigit
())
width
=
int
(
digits
)
if
digits
else
0
if
len
(
parts
)
>=
3
:
base_name
=
","
.
join
(
parts
[
2
:])
return
op
,
width
,
base_name
def
parse_dot_graph_with_meta
(
dot_path
:
Path
)
->
Tuple
[
List
[
NodeInfo
],
List
[
Tuple
[
str
,
str
,
str
]]]:
node_infos
:
List
[
NodeInfo
]
=
[]
node_map
:
Dict
[
str
,
NodeInfo
]
=
{}
edges
:
List
[
Tuple
[
str
,
str
,
str
]]
=
[]
edges
:
List
[
Tuple
[
str
,
str
,
str
]]
=
[]
def
add_node
(
name
:
str
,
attrs
:
Optional
[
Dict
[
str
,
str
]]
=
None
)
->
None
:
def
add_node
(
name
:
str
,
attrs
:
Optional
[
Dict
[
str
,
str
]]
=
None
)
->
None
:
if
name
not
in
node_attrs
:
if
name
not
in
node_map
:
node_order
.
append
(
name
)
op
,
width
,
base_name
=
_parse_node_name
(
name
)
node_attrs
[
name
]
=
attrs
or
{}
info
=
NodeInfo
(
full_name
=
name
,
op
=
op
,
width
=
width
,
base_name
=
base_name
,
attrs
=
attrs
or
{},
)
node_infos
.
append
(
info
)
node_map
[
name
]
=
info
elif
attrs
:
elif
attrs
:
node_
attrs
[
name
]
.
update
(
attrs
)
node_
map
[
name
]
.
attrs
.
update
(
attrs
)
with
dot_path
.
open
(
"r"
)
as
f
:
with
dot_path
.
open
(
"r"
)
as
f
:
for
raw
in
f
:
for
raw
in
f
:
...
@@ -63,7 +99,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
...
@@ -63,7 +99,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
if
not
line
or
line
.
startswith
(
"digraph"
)
or
line
in
{
"{"
,
"}"
}:
if
not
line
or
line
.
startswith
(
"digraph"
)
or
line
in
{
"{"
,
"}"
}:
continue
continue
if
"->"
in
line
:
if
"->"
in
line
:
# Edge line
m
=
re
.
search
(
r'\"?(.+?)\"?\\s*->\\s*\"?(.+?)\"?\\s*\\[label=\"(.*?)\"\\]'
,
line
)
m
=
re
.
search
(
r'\"?(.+?)\"?\\s*->\\s*\"?(.+?)\"?\\s*\\[label=\"(.*?)\"\\]'
,
line
)
if
not
m
:
if
not
m
:
continue
continue
...
@@ -75,7 +110,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
...
@@ -75,7 +110,6 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
edges
.
append
((
src
,
dst
,
label
))
edges
.
append
((
src
,
dst
,
label
))
continue
continue
# Node line
if
"["
in
line
and
"]"
in
line
:
if
"["
in
line
and
"]"
in
line
:
name_part
,
attr_part
=
line
.
split
(
"["
,
1
)
name_part
,
attr_part
=
line
.
split
(
"["
,
1
)
name
=
name_part
.
strip
()
.
strip
(
";"
)
.
strip
()
.
strip
(
"
\"
"
)
name
=
name_part
.
strip
()
.
strip
(
";"
)
.
strip
()
.
strip
(
"
\"
"
)
...
@@ -87,41 +121,102 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
...
@@ -87,41 +121,102 @@ def parse_dot_graph(dot_path: Path) -> Tuple[List[str], Dict[str, Dict[str, str]
if
name
:
if
name
:
add_node
(
name
)
add_node
(
name
)
return
node_infos
,
edges
def
parse_dot_graph
(
dot_path
:
Path
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
Dict
[
str
,
str
]],
List
[
Tuple
[
str
,
str
,
str
]]]:
node_infos
,
edges
=
parse_dot_graph_with_meta
(
dot_path
)
node_order
=
[
info
.
full_name
for
info
in
node_infos
]
node_attrs
=
{
info
.
full_name
:
info
.
attrs
for
info
in
node_infos
}
return
node_order
,
node_attrs
,
edges
return
node_order
,
node_attrs
,
edges
def
build_graph
(
dot_path
:
Path
)
->
GraphData
:
def
_edge_label_to_int
(
label
:
str
)
->
int
:
node_order
,
node_attrs
,
edges
=
parse_dot_graph
(
dot_path
)
digits
=
""
.
join
(
ch
for
ch
in
label
if
ch
.
isdigit
())
return
int
(
digits
)
if
digits
else
0
node_types
:
List
[
int
]
=
[]
for
name
in
node_order
:
def
_load_category_map
(
path
:
Path
,
ops
:
List
[
str
])
->
Dict
[
str
,
str
]:
attrs
=
node_attrs
.
get
(
name
,
{})
if
path
.
exists
():
color
=
attrs
.
get
(
"color"
)
with
path
.
open
(
"r"
)
as
f
:
node_types
.
append
(
COLOR_TO_TYPE
.
get
(
color
,
0
))
category_map
=
json
.
load
(
f
)
else
:
category_map
=
{}
updated
=
False
for
op
in
ops
:
if
op
not
in
category_map
:
category_map
[
op
]
=
op
updated
=
True
if
updated
or
not
path
.
exists
():
path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
with
path
.
open
(
"w"
)
as
f
:
json
.
dump
(
category_map
,
f
,
indent
=
2
)
return
category_map
def
build_graph
(
dot_path
:
Path
,
node_type_mode
:
str
=
"color"
,
edge_type_mode
:
str
=
"label_map"
,
category_map_path
:
Optional
[
Path
]
=
None
,
)
->
Tuple
[
GraphData
,
List
[
NodeInfo
],
Optional
[
Dict
[
str
,
int
]],
Optional
[
Dict
[
str
,
str
]]]:
node_infos
,
edges
=
parse_dot_graph_with_meta
(
dot_path
)
node_names
=
[
info
.
full_name
for
info
in
node_infos
]
node_idx
=
{
name
:
idx
for
idx
,
name
in
enumerate
(
node_names
)}
op_to_index
:
Optional
[
Dict
[
str
,
int
]]
=
None
category_map
:
Optional
[
Dict
[
str
,
str
]]
=
None
if
node_type_mode
==
"category"
:
ops
=
sorted
({
info
.
op
for
info
in
node_infos
})
if
category_map_path
is
None
:
category_map_path
=
Path
(
"operator_types"
)
/
"category_map.json"
category_map
=
_load_category_map
(
category_map_path
,
ops
)
categories
=
sorted
(
set
(
category_map
.
values
()))
op_to_index
=
{
cat
:
idx
for
idx
,
cat
in
enumerate
(
categories
)}
node_types
=
[]
for
info
in
node_infos
:
cat
=
category_map
.
get
(
info
.
op
,
info
.
op
)
if
cat
not
in
op_to_index
:
op_to_index
[
cat
]
=
len
(
op_to_index
)
node_types
.
append
(
op_to_index
[
cat
])
elif
node_type_mode
==
"op"
:
ops
=
sorted
({
info
.
op
for
info
in
node_infos
})
op_to_index
=
{
op
:
idx
for
idx
,
op
in
enumerate
(
ops
)}
node_types
=
[
op_to_index
[
info
.
op
]
for
info
in
node_infos
]
else
:
node_types
=
[]
for
info
in
node_infos
:
color
=
info
.
attrs
.
get
(
"color"
)
node_types
.
append
(
COLOR_TO_TYPE
.
get
(
color
,
0
))
label_map
:
Dict
[
str
,
int
]
=
{}
label_map
:
Dict
[
str
,
int
]
=
{}
edge_index
:
List
[
List
[
int
]]
=
[
[],
[]
]
edge_index
:
List
[
List
[
int
]]
=
[]
edge_type
:
List
[
int
]
=
[]
edge_type
:
List
[
int
]
=
[]
node_idx
=
{
name
:
idx
for
idx
,
name
in
enumerate
(
node_order
)}
for
src
,
dst
,
label
in
edges
:
for
src
,
dst
,
label
in
edges
:
if
label
not
in
label_map
:
if
src
not
in
node_idx
or
dst
not
in
node_idx
:
label_map
[
label
]
=
len
(
label_map
)
continue
edge_index
[
0
]
.
append
(
node_idx
[
src
])
if
edge_type_mode
==
"numeric"
:
edge_index
[
1
]
.
append
(
node_idx
[
dst
])
etype
=
_edge_label_to_int
(
label
)
edge_type
.
append
(
label_map
[
label
])
label_map
.
setdefault
(
label
,
etype
)
else
:
if
label
not
in
label_map
:
label_map
[
label
]
=
len
(
label_map
)
etype
=
label_map
[
label
]
edge_index
.
append
([
node_idx
[
src
],
node_idx
[
dst
]])
edge_type
.
append
(
etype
)
edge_index_arr
=
np
.
array
(
edge_index
,
dtype
=
np
.
int64
)
edge_index_arr
=
np
.
array
(
edge_index
,
dtype
=
np
.
int64
)
edge_type_arr
=
np
.
array
(
edge_type
,
dtype
=
np
.
int64
)
edge_type_arr
=
np
.
array
(
edge_type
,
dtype
=
np
.
int64
)
return
GraphData
(
node_order
,
node_types
,
edge_index_arr
,
edge_type_arr
,
label_map
)
graph
=
GraphData
(
node_names
,
node_types
,
edge_index_arr
,
edge_type_arr
,
label_map
)
return
graph
,
node_infos
,
op_to_index
,
category_map
def
_parse_vcd_var_line
(
tokens
:
List
[
str
])
->
Tuple
[
int
,
str
,
str
]:
def
_parse_vcd_var_line
(
tokens
:
List
[
str
])
->
Tuple
[
int
,
str
,
str
]:
width
=
int
(
tokens
[
2
])
width
=
int
(
tokens
[
2
])
var_id
=
tokens
[
3
]
var_id
=
tokens
[
3
]
var_name
=
tokens
[
4
]
var_name
=
tokens
[
4
]
if
len
(
tokens
)
>
6
and
tokens
[
5
]
.
startswith
(
"["
):
var_name
=
f
"{var_name}{tokens[5]}"
return
width
,
var_id
,
var_name
return
width
,
var_id
,
var_name
...
@@ -135,61 +230,89 @@ def _value_to_int(value: str) -> int:
...
@@ -135,61 +230,89 @@ def _value_to_int(value: str) -> int:
return
int
(
value
,
2
)
return
int
(
value
,
2
)
def
_int_to_bitvector
(
value
:
int
,
width
:
int
,
out_width
:
int
)
->
List
[
int
]:
if
out_width
<=
0
:
raise
ValueError
(
"sim_res_width must be positive"
)
if
value
<
0
:
return
[
0
]
*
out_width
if
width
<=
0
:
width
=
out_width
mask
=
(
1
<<
width
)
-
1
if
width
<
63
else
(
1
<<
out_width
)
-
1
value
=
value
&
mask
bits
=
[(
value
>>
(
width
-
1
-
i
))
&
1
for
i
in
range
(
width
)]
if
width
<
out_width
:
bits
=
[
0
]
*
(
out_width
-
width
)
+
bits
elif
width
>
out_width
:
bits
=
bits
[
-
out_width
:]
return
bits
def
parse_vcd_trace
(
def
parse_vcd_trace
(
vcd_path
:
Path
,
vcd_path
:
Path
,
node_names
:
List
[
str
],
signal_map
:
Dict
[
str
,
List
[
int
]],
node_widths
:
List
[
int
],
clock
:
str
,
clock
:
str
,
reset
:
Optional
[
str
],
reset
:
Optional
[
str
],
reset_active
:
str
,
reset_active
:
str
,
scope
:
Optional
[
str
]
=
None
,
scope
:
Optional
[
str
]
=
None
,
clock_edge
:
str
=
"posedge"
,
clock_edge
:
str
=
"posedge"
,
max_cycles
:
Optional
[
int
]
=
None
,
max_cycles
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
np
.
ndarray
,
Dict
[
str
,
int
],
List
[
str
]]:
sim_res_mode
:
str
=
"scalar"
,
node_set
=
set
(
node_names
)
sim_res_width
:
int
=
32
,
)
->
Tuple
[
np
.
ndarray
,
Dict
[
int
,
int
],
set
[
int
]]:
num_nodes
=
len
(
node_widths
)
clock_san
=
sanitize_name
(
clock
)
clock_san
=
sanitize_name
(
clock
)
reset_san
=
sanitize_name
(
reset
)
if
reset
else
None
reset_san
=
sanitize_name
(
reset
)
if
reset
else
None
id_to_signal
:
Dict
[
str
,
str
]
=
{}
id_to_nodes
:
Dict
[
str
,
List
[
int
]]
=
{}
signal_widths
:
Dict
[
str
,
int
]
=
{}
signal_widths
:
Dict
[
int
,
int
]
=
{}
current_values
:
Dict
[
str
,
int
]
=
{}
mapped_nodes
:
set
[
int
]
=
set
()
current_values
=
[
-
1
]
*
num_nodes
scope_stack
:
List
[
str
]
=
[]
scope_stack
:
List
[
str
]
=
[]
clock_id
:
Optional
[
str
]
=
None
clock_id
:
Optional
[
str
]
=
None
reset_id
:
Optional
[
str
]
=
None
reset_id
:
Optional
[
str
]
=
None
clock_val
=
-
1
reset_val
=
-
1
def
use_signal
(
rel_name
:
str
,
width
:
int
,
var_id
:
str
)
->
None
:
def
use_signal
(
rel_name
:
str
,
width
:
int
,
var_id
:
str
)
->
None
:
nonlocal
clock_id
,
reset_id
nonlocal
clock_id
,
reset_id
name_san
=
sanitize_name
(
rel_name
)
name_san
=
sanitize_name
(
rel_name
)
if
name_san
in
node_set
or
name_san
in
{
clock_san
,
reset_san
}:
if
name_san
==
clock_san
:
id_to_signal
[
var_id
]
=
name_san
clock_id
=
var_id
signal_widths
.
setdefault
(
name_san
,
width
)
if
reset_san
and
name_san
==
reset_san
:
current_values
.
setdefault
(
name_san
,
-
1
)
reset_id
=
var_id
if
name_san
==
clock_san
:
if
name_san
in
signal_map
:
clock_id
=
var_id
id_to_nodes
[
var_id
]
=
signal_map
[
name_san
]
if
reset_san
and
name_san
==
reset_san
:
for
idx
in
signal_map
[
name_san
]:
reset_id
=
var_id
mapped_nodes
.
add
(
idx
)
signal_widths
.
setdefault
(
idx
,
width
)
in_defs
=
True
in_dumpvars
=
False
last_clock
:
Optional
[
int
]
=
None
last_clock
:
Optional
[
int
]
=
None
pending_sample
=
False
pending_sample
=
False
samples_collected
=
0
samples_collected
=
0
values_per_node
=
[[]
for
_
in
node_names
]
values_per_node
:
List
[
list
]
=
[[]
for
_
in
range
(
num_nodes
)
]
def
record_sample
()
->
None
:
def
record_sample
()
->
None
:
nonlocal
samples_collected
nonlocal
samples_collected
if
max_cycles
is
not
None
and
samples_collected
>=
max_cycles
:
if
max_cycles
is
not
None
and
samples_collected
>=
max_cycles
:
return
return
if
reset_id
is
not
None
:
if
reset_id
is
not
None
:
reset_val
=
current_values
.
get
(
reset_san
,
-
1
)
if
reset_val
==
-
1
:
if
reset_val
==
-
1
:
return
return
if
reset_active
==
"high"
and
reset_val
==
1
:
if
reset_active
==
"high"
and
reset_val
==
1
:
return
return
if
reset_active
==
"low"
and
reset_val
==
0
:
if
reset_active
==
"low"
and
reset_val
==
0
:
return
return
for
idx
,
name
in
enumerate
(
node_names
):
if
sim_res_mode
==
"scalar"
:
values_per_node
[
idx
]
.
append
(
current_values
.
get
(
name
,
-
1
))
for
idx
in
range
(
num_nodes
):
values_per_node
[
idx
]
.
append
(
current_values
[
idx
])
elif
sim_res_mode
==
"bitvector"
:
for
idx
in
range
(
num_nodes
):
width
=
node_widths
[
idx
]
if
node_widths
[
idx
]
>
0
else
sim_res_width
values_per_node
[
idx
]
.
append
(
_int_to_bitvector
(
current_values
[
idx
],
width
,
sim_res_width
))
else
:
raise
ValueError
(
f
"Unsupported sim_res_mode: {sim_res_mode}"
)
samples_collected
+=
1
samples_collected
+=
1
with
vcd_path
.
open
(
"r"
)
as
f
:
with
vcd_path
.
open
(
"r"
)
as
f
:
...
@@ -221,13 +344,8 @@ def parse_vcd_trace(
...
@@ -221,13 +344,8 @@ def parse_vcd_trace(
use_signal
(
rel_name
,
width
,
var_id
)
use_signal
(
rel_name
,
width
,
var_id
)
continue
continue
if
line
.
startswith
(
"$enddefinitions"
):
if
line
.
startswith
(
"$enddefinitions"
):
in_defs
=
False
continue
continue
if
line
.
startswith
(
"$dumpvars"
):
if
line
.
startswith
(
"$dumpvars"
):
in_dumpvars
=
True
continue
if
in_dumpvars
and
line
.
startswith
(
"$end"
):
in_dumpvars
=
False
continue
continue
if
line
.
startswith
(
"#"
):
if
line
.
startswith
(
"#"
):
if
pending_sample
:
if
pending_sample
:
...
@@ -237,7 +355,6 @@ def parse_vcd_trace(
...
@@ -237,7 +355,6 @@ def parse_vcd_trace(
break
break
continue
continue
# Value change
if
line
[
0
]
in
"01xXzZ"
:
if
line
[
0
]
in
"01xXzZ"
:
value
=
line
[
0
]
value
=
line
[
0
]
var_id
=
line
[
1
:]
.
strip
()
var_id
=
line
[
1
:]
.
strip
()
...
@@ -250,12 +367,17 @@ def parse_vcd_trace(
...
@@ -250,12 +367,17 @@ def parse_vcd_trace(
else
:
else
:
continue
continue
if
var_id
not
in
id_to_signal
:
continue
signal_name
=
id_to_signal
[
var_id
]
value_int
=
_value_to_int
(
value
)
value_int
=
_value_to_int
(
value
)
current_values
[
signal_name
]
=
value_int
if
var_id
==
clock_id
:
clock_val
=
value_int
if
var_id
==
reset_id
:
reset_val
=
value_int
if
var_id
not
in
id_to_nodes
:
continue
for
idx
in
id_to_nodes
[
var_id
]:
current_values
[
idx
]
=
value_int
if
var_id
==
clock_id
and
value_int
!=
-
1
:
if
var_id
==
clock_id
and
value_int
!=
-
1
:
if
last_clock
is
None
:
if
last_clock
is
None
:
...
@@ -271,7 +393,6 @@ def parse_vcd_trace(
...
@@ -271,7 +393,6 @@ def parse_vcd_trace(
record_sample
()
record_sample
()
trace
=
np
.
array
(
values_per_node
,
dtype
=
np
.
int64
)
trace
=
np
.
array
(
values_per_node
,
dtype
=
np
.
int64
)
mapped_nodes
=
[
name
for
name
in
node_names
if
name
in
signal_widths
]
return
trace
,
signal_widths
,
mapped_nodes
return
trace
,
signal_widths
,
mapped_nodes
...
@@ -284,6 +405,15 @@ def build_dataset(
...
@@ -284,6 +405,15 @@ def build_dataset(
dataset_cfg
=
config
.
get
(
"dataset"
,
{})
dataset_cfg
=
config
.
get
(
"dataset"
,
{})
wave_cfg
=
config
.
get
(
"waveform"
,
{})
wave_cfg
=
config
.
get
(
"waveform"
,
{})
node_type_mode
=
dataset_cfg
.
get
(
"node_type_mode"
,
"color"
)
edge_type_mode
=
dataset_cfg
.
get
(
"edge_type_mode"
,
"label_map"
)
sim_res_mode
=
dataset_cfg
.
get
(
"sim_res_mode"
,
"scalar"
)
sim_res_width
=
int
(
dataset_cfg
.
get
(
"sim_res_width"
,
32
))
signal_ops
=
dataset_cfg
.
get
(
"signal_ops"
)
op_to_index_path
=
dataset_cfg
.
get
(
"op_to_index_path"
)
category_map_path_cfg
=
dataset_cfg
.
get
(
"category_map_path"
)
build_labels
=
bool
(
dataset_cfg
.
get
(
"build_labels"
,
False
))
cdfg_output_dir
=
cdfg_cfg
.
get
(
"output_dir"
)
cdfg_output_dir
=
cdfg_cfg
.
get
(
"output_dir"
)
if
cdfg_output_dir
:
if
cdfg_output_dir
:
cdfg_dir
=
(
repo_root
/
cdfg_output_dir
)
.
resolve
()
if
not
Path
(
cdfg_output_dir
)
.
is_absolute
()
else
Path
(
cdfg_output_dir
)
cdfg_dir
=
(
repo_root
/
cdfg_output_dir
)
.
resolve
()
if
not
Path
(
cdfg_output_dir
)
.
is_absolute
()
else
Path
(
cdfg_output_dir
)
...
@@ -302,7 +432,17 @@ def build_dataset(
...
@@ -302,7 +432,17 @@ def build_dataset(
raise
FileNotFoundError
(
f
"No .dot files found in {cdfg_dir}"
)
raise
FileNotFoundError
(
f
"No .dot files found in {cdfg_dir}"
)
dot_path
=
dot_files
[
0
]
dot_path
=
dot_files
[
0
]
graph
=
build_graph
(
dot_path
)
category_map_path
=
None
if
category_map_path_cfg
:
category_map_path
=
Path
(
category_map_path_cfg
)
if
not
category_map_path
.
is_absolute
():
category_map_path
=
(
repo_root
/
category_map_path
)
.
resolve
()
graph
,
node_infos
,
op_to_index
,
category_map
=
build_graph
(
dot_path
,
node_type_mode
=
node_type_mode
,
edge_type_mode
=
edge_type_mode
,
category_map_path
=
category_map_path
,
)
vcd_glob
=
wave_cfg
.
get
(
"vcd_glob"
)
vcd_glob
=
wave_cfg
.
get
(
"vcd_glob"
)
vcd_files
=
wave_cfg
.
get
(
"vcd_files"
)
vcd_files
=
wave_cfg
.
get
(
"vcd_files"
)
...
@@ -333,33 +473,59 @@ def build_dataset(
...
@@ -333,33 +473,59 @@ def build_dataset(
raise
KeyError
(
"Missing waveform.clock in config.toml"
)
raise
KeyError
(
"Missing waveform.clock in config.toml"
)
if
reset_active
not
in
{
"high"
,
"low"
}:
if
reset_active
not
in
{
"high"
,
"low"
}:
raise
ValueError
(
"waveform.reset_active must be 'high' or 'low'"
)
raise
ValueError
(
"waveform.reset_active must be 'high' or 'low'"
)
if
sim_res_mode
==
"bitvector"
and
sim_res_width
<=
0
:
raise
ValueError
(
"dataset.sim_res_width must be positive for bitvector mode"
)
if
signal_ops
is
None
:
signal_ops_set
=
{
"Input"
,
"Output"
,
"Wire"
,
"Reg"
}
else
:
signal_ops_set
=
{
str
(
op
)
for
op
in
signal_ops
}
signal_map
:
Dict
[
str
,
List
[
int
]]
=
{}
node_widths
:
List
[
int
]
=
[]
for
idx
,
info
in
enumerate
(
node_infos
):
width
=
info
.
width
if
info
.
width
>
0
else
1
node_widths
.
append
(
width
)
if
info
.
op
not
in
signal_ops_set
:
continue
base
=
sanitize_name
(
info
.
base_name
)
if
not
base
:
continue
signal_map
.
setdefault
(
base
,
[])
.
append
(
idx
)
sim_res
:
List
[
np
.
ndarray
]
=
[]
sim_res
:
List
[
np
.
ndarray
]
=
[]
signal_widths
:
Dict
[
str
,
int
]
=
{}
signal_widths
:
Dict
[
int
,
int
]
=
{}
mapped_nodes
:
set
[
str
]
=
set
()
mapped_nodes
:
set
[
int
]
=
set
()
trace_names
:
List
[
str
]
=
[]
trace_names
:
List
[
str
]
=
[]
for
vcd_path
in
vcd_paths
:
for
vcd_path
in
vcd_paths
:
trace
,
widths
,
mapped
=
parse_vcd_trace
(
trace
,
widths
,
mapped
=
parse_vcd_trace
(
vcd_path
=
vcd_path
,
vcd_path
=
vcd_path
,
node_names
=
graph
.
node_names
,
signal_map
=
signal_map
,
node_widths
=
node_widths
,
clock
=
clock
,
clock
=
clock
,
reset
=
reset
,
reset
=
reset
,
reset_active
=
reset_active
,
reset_active
=
reset_active
,
scope
=
scope
,
scope
=
scope
,
clock_edge
=
clock_edge
,
clock_edge
=
clock_edge
,
max_cycles
=
max_cycles
,
max_cycles
=
max_cycles
,
sim_res_mode
=
sim_res_mode
,
sim_res_width
=
sim_res_width
,
)
)
sim_res
.
append
(
trace
)
sim_res
.
append
(
trace
)
signal_widths
.
update
(
widths
)
signal_widths
.
update
(
widths
)
mapped_nodes
.
update
(
mapped
)
mapped_nodes
.
update
(
mapped
)
trace_names
.
append
(
str
(
vcd_path
))
trace_names
.
append
(
str
(
vcd_path
))
has_sim_res
=
np
.
array
([
1
if
name
in
mapped_nodes
else
0
for
name
in
graph
.
node_names
],
dtype
=
np
.
int64
)
for
idx
,
width
in
signal_widths
.
items
():
node_widths
=
np
.
array
([
signal_widths
.
get
(
name
,
1
)
for
name
in
graph
.
node_names
],
dtype
=
np
.
int64
)
if
node_widths
[
idx
]
<=
1
and
width
>
1
:
node_widths
[
idx
]
=
width
has_sim_res
=
np
.
array
([
1
if
idx
in
mapped_nodes
else
0
for
idx
in
range
(
len
(
graph
.
node_names
))],
dtype
=
np
.
int64
)
node_ids
=
np
.
arange
(
len
(
graph
.
node_names
),
dtype
=
np
.
int64
)
node_ids
=
np
.
arange
(
len
(
graph
.
node_names
),
dtype
=
np
.
int64
)
node_types
=
np
.
array
(
graph
.
node_types
,
dtype
=
np
.
int64
)
node_types
=
np
.
array
(
graph
.
node_types
,
dtype
=
np
.
int64
)
x
=
np
.
stack
([
node_ids
,
node_types
,
node_widths
],
axis
=
1
)
node_widths_arr
=
np
.
array
(
node_widths
,
dtype
=
np
.
int64
)
x
=
np
.
stack
([
node_ids
,
node_types
,
node_widths_arr
],
axis
=
1
)
designs
=
{
designs
=
{
case_name
:
{
case_name
:
{
...
@@ -383,11 +549,25 @@ def build_dataset(
...
@@ -383,11 +549,25 @@ def build_dataset(
output_npz
=
output_dir
/
"graphs.npz"
output_npz
=
output_dir
/
"graphs.npz"
np
.
savez_compressed
(
output_npz
,
designs
=
designs
)
np
.
savez_compressed
(
output_npz
,
designs
=
designs
)
if
build_labels
:
labels
=
{
name
:
{
"y"
:
0
}
for
name
in
designs
}
labels_npz
=
output_dir
/
"labels.npz"
np
.
savez_compressed
(
labels_npz
,
labels
=
labels
)
if
op_to_index
and
op_to_index_path
:
op_path
=
Path
(
op_to_index_path
)
if
not
op_path
.
is_absolute
():
op_path
=
(
repo_root
/
op_path
)
.
resolve
()
op_path
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
with
op_path
.
open
(
"w"
)
as
f
:
json
.
dump
(
op_to_index
,
f
,
indent
=
2
)
meta
=
{
meta
=
{
"case"
:
case_name
,
"case"
:
case_name
,
"cdfg_dot"
:
str
(
dot_path
),
"cdfg_dot"
:
str
(
dot_path
),
"vcd_files"
:
trace_names
,
"vcd_files"
:
trace_names
,
"node_names"
:
graph
.
node_names
,
"node_names"
:
graph
.
node_names
,
"node_base_names"
:
[
info
.
base_name
for
info
in
node_infos
],
"edge_label_map"
:
graph
.
edge_label_map
,
"edge_label_map"
:
graph
.
edge_label_map
,
"clock"
:
clock
,
"clock"
:
clock
,
"reset"
:
reset
,
"reset"
:
reset
,
...
@@ -395,7 +575,20 @@ def build_dataset(
...
@@ -395,7 +575,20 @@ def build_dataset(
"scope"
:
scope
,
"scope"
:
scope
,
"clock_edge"
:
clock_edge
,
"clock_edge"
:
clock_edge
,
"max_cycles"
:
max_cycles
,
"max_cycles"
:
max_cycles
,
"node_type_mode"
:
node_type_mode
,
"edge_type_mode"
:
edge_type_mode
,
"sim_res_mode"
:
sim_res_mode
,
"sim_res_width"
:
sim_res_width
,
"signal_ops"
:
sorted
(
signal_ops_set
),
}
}
if
op_to_index
:
meta
[
"op_to_index"
]
=
op_to_index
if
op_to_index_path
:
meta
[
"op_to_index_path"
]
=
op_to_index_path
if
category_map
:
meta
[
"category_map"
]
=
category_map
if
category_map_path_cfg
:
meta
[
"category_map_path"
]
=
category_map_path_cfg
with
(
output_dir
/
"meta.json"
)
.
open
(
"w"
)
as
f
:
with
(
output_dir
/
"meta.json"
)
.
open
(
"w"
)
as
f
:
json
.
dump
(
meta
,
f
,
indent
=
2
)
json
.
dump
(
meta
,
f
,
indent
=
2
)
...
...
operator_types/category_map.json
0 → 100644
View file @
ce346a35
{
"Add"
:
"Add"
,
"BitAnd"
:
"BitAnd"
,
"BitXor"
:
"BitXor"
,
"Concat"
:
"Concat"
,
"Cond"
:
"Cond"
,
"Cond_If"
:
"Cond_If"
,
"Const"
:
"Const"
,
"Eq"
:
"Eq"
,
"Input"
:
"Input"
,
"Not"
:
"Not"
,
"Output"
:
"Output"
,
"PartSelect"
:
"PartSelect"
,
"Reg"
:
"Reg"
,
"Sub"
:
"Sub"
,
"URand"
:
"URand"
,
"URor"
:
"URor"
,
"Wire"
:
"Wire"
}
\ No newline at end of file
operator_types/op_to_index.json
0 → 100644
View file @
ce346a35
{
"Add"
:
0
,
"BitAnd"
:
1
,
"BitXor"
:
2
,
"Concat"
:
3
,
"Cond"
:
4
,
"Cond_If"
:
5
,
"Const"
:
6
,
"Eq"
:
7
,
"Input"
:
8
,
"Not"
:
9
,
"Output"
:
10
,
"PartSelect"
:
11
,
"Reg"
:
12
,
"Sub"
:
13
,
"URand"
:
14
,
"URor"
:
15
,
"Wire"
:
16
}
\ No newline at end of file
train_stage1.py
0 → 100644
View file @
ce346a35
#!/usr/bin/env python3
from
__future__
import
annotations
import
argparse
import
json
import
os
import
sys
import
time
from
pathlib
import
Path
import
numpy
as
np
import
torch
REPO_ROOT
=
Path
(
__file__
)
.
resolve
()
.
parent
MODEL_EXAMPLE_SRC
=
REPO_ROOT
/
"model_example"
/
"src"
def _ensure_op_to_index(data_dir: Path) -> None:
    """Materialize operator_types/op_to_index.json from the dataset metadata.

    Reads ``meta.json`` in *data_dir* (if present) for an explicit
    ``op_to_index`` mapping and an optional override path; otherwise
    derives the mapping from the node names. Warns (without raising) when
    no mapping can be built or when commonly required ops are absent.
    """
    meta_path = data_dir / "meta.json"
    op_to_index = None
    op_path = REPO_ROOT / "operator_types" / "op_to_index.json"

    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        op_to_index = meta.get("op_to_index")
        op_path_cfg = meta.get("op_to_index_path")
        if op_path_cfg:
            cfg_path = Path(op_path_cfg)
            if not cfg_path.is_absolute():
                cfg_path = (REPO_ROOT / cfg_path).resolve()
            op_path = cfg_path
        if op_to_index is None:
            # Fall back to deriving ops from node names ("op,width,base").
            node_names = meta.get("node_names", [])
            ops = sorted({name.split(",")[0] for name in node_names if name})
            if ops:
                op_to_index = {op: idx for idx, op in enumerate(ops)}

    if op_to_index is None:
        print("[train_stage1] Warning: op_to_index not found; training may fail.")
        return

    op_path.parent.mkdir(parents=True, exist_ok=True)
    op_path.write_text(json.dumps(op_to_index, indent=2))
    print(f"[train_stage1] op_to_index saved to {op_path}")

    required_keys = {"Input", "Const", "Wire", "Reg", "Cond", "Output"}
    missing = required_keys - set(op_to_index.keys())
    if missing:
        print(
            f"[train_stage1] Warning: op_to_index missing keys {sorted(missing)}. "
            "If you merged categories, update the model or mapping accordingly."
        )
def
_ensure_labels
(
data_dir
:
Path
,
graph_npz
:
Path
,
label_npz
:
Path
)
->
None
:
if
label_npz
.
exists
():
return
if
not
graph_npz
.
exists
():
raise
FileNotFoundError
(
f
"graphs.npz not found: {graph_npz}"
)
with
np
.
load
(
graph_npz
,
allow_pickle
=
True
)
as
data
:
designs
=
data
[
"designs"
]
.
item
()
labels
=
{
name
:
{
"y"
:
0
}
for
name
in
designs
}
label_npz
.
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
np
.
savez_compressed
(
label_npz
,
labels
=
labels
)
print
(
f
"[train_stage1] labels.npz created at {label_npz}"
)
def
_select_device
(
device_arg
:
str
)
->
torch
.
device
:
if
device_arg
.
startswith
(
"cuda"
)
and
torch
.
cuda
.
is_available
():
return
torch
.
device
(
device_arg
)
return
torch
.
device
(
"cpu"
)
def main() -> int:
    """Entry point: parse CLI args, prepare dataset artifacts, run stage-1 training."""
    parser = argparse.ArgumentParser(description="Stage-1 training (DynamicRTL style)")
    parser.add_argument("--case", required=True, help="Case name under cases/")
    parser.add_argument("--data_dir", default="", help="Dataset dir (default: results/<case>/dataset)")
    parser.add_argument("--graph_npz_name", default="graphs.npz", type=str)
    parser.add_argument("--label_npz_name", default="labels.npz", type=str)
    parser.add_argument("--distributed", action="store_true", help="If set, train in distributed mode")
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_step", default=50, type=int)
    parser.add_argument("--num_epochs", default=60, type=int)
    parser.add_argument("--num_rounds", default=20, type=int, help="Number of rounds to GNN propagate")
    parser.add_argument("--train_seq_len", default=50, type=int)
    parser.add_argument("--eval_seq_len", default=50, type=int)
    parser.add_argument("--device", default="cuda", type=str, help="cpu / cuda / cuda:0")
    parser.add_argument("--gpus", default="0", type=str, help="GPU IDs to use, example: 0,1,2,3")
    parser.add_argument("--model", default="default_shared", type=str, help="default_shared or default")
    parser.add_argument("--exp_id", default="stage1", type=str, help="Experiment ID")
    parser.add_argument("--supervision", default="default", type=str, help="only_branch/only_tgl/default")
    parser.add_argument("--split_with_design", action="store_true", help="Split dataset by design name")
    args = parser.parse_args()

    # Resolve dataset locations and make sure required artifacts exist.
    data_dir = Path(args.data_dir) if args.data_dir else (REPO_ROOT / "results" / args.case / "dataset")
    graph_npz = data_dir / args.graph_npz_name
    label_npz = data_dir / args.label_npz_name
    os.chdir(REPO_ROOT)
    _ensure_op_to_index(data_dir)
    _ensure_labels(data_dir, graph_npz, label_npz)

    # The model code lives in model_example/src; import it lazily so the
    # preparation steps above run even when that tree is absent.
    sys.path.insert(0, str(MODEL_EXAMPLE_SRC))
    from npz_parser import NpzParser  # noqa: E402
    from model_arch import Model_default, Model_shared  # noqa: E402
    from trainer import Trainer  # noqa: E402

    model_factory = {
        "default_shared": Model_shared,
        "default": Model_default,
    }
    if args.model not in model_factory:
        raise ValueError(f"Model not supported: {args.model}")
    if args.model == "default":
        print("[train_stage1] Warning: Model_default assumes fixed operator indices. Prefer --model default_shared.")

    device = _select_device(args.device)
    print(f"[train_stage1] Using device: {device}")

    dataset = NpzParser(str(data_dir), str(graph_npz), str(label_npz))
    train_dataset, val_dataset = dataset.get_dataset(split_with_design=args.split_with_design)
    model = model_factory[args.model](num_rounds=args.num_rounds)

    time_str = time.strftime("%Y-%m-%d-%H-%M")
    trainer = Trainer(
        args,
        model,
        distributed=args.distributed,
        batch_size=args.batch_size,
        device=device,
        gpus=args.gpus,
        training_id=time_str,
    )
    trainer.set_training_args(lr=args.lr, lr_step=args.lr_step)

    print("[train_stage1] Stage 1 Training ...")
    trainer.train(
        args.num_epochs,
        train_dataset,
        val_dataset,
        train_seq_len=args.train_seq_len,
        eval_seq_len=args.eval_seq_len,
        supervision=args.supervision,
    )
    print("[train_stage1] Finish Training")
    return 0
if
__name__
==
"__main__"
:
raise
SystemExit
(
main
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment