Commit 7789080d by Dinple

observation extractor test done

parent 3c1362f8
# coding=utf-8
# Copyright 2021 The Circuit Training Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class to store the observation shape and sizes."""
from typing import Dict, List, Optional, Text, Tuple, Union
import gin
import gym
import numpy as np
import tensorflow as tf
TensorType = Union[np.ndarray, tf.Tensor]
FeatureKeyType = Union[List[Text], Tuple[Text, ...]]
HARD_MACRO = 1
SOFT_MACRO = 2
PORT_CLUSTER = 3
NETLIST_METADATA = (
'normalized_num_edges',
'normalized_num_hard_macros',
'normalized_num_soft_macros',
'normalized_num_port_clusters',
'horizontal_routes_per_micron',
'vertical_routes_per_micron',
'macro_horizontal_routing_allocation',
'macro_vertical_routing_allocation',
'grid_cols',
'grid_rows',
)
GRAPH_ADJACENCY_MATRIX = ('sparse_adj_i', 'sparse_adj_j', 'sparse_adj_weight',
'edge_counts')
NODE_STATIC_FEATURES = (
'macros_w',
'macros_h',
'node_types',
)
STATIC_OBSERVATIONS = (
NETLIST_METADATA + GRAPH_ADJACENCY_MATRIX + NODE_STATIC_FEATURES)
INITIAL_DYNAMIC_OBSERVATIONS = (
'locations_x',
'locations_y',
'is_node_placed',
)
DYNAMIC_OBSERVATIONS = (
'locations_x',
'locations_y',
'is_node_placed',
'current_node',
'mask',
)
ALL_OBSERVATIONS = STATIC_OBSERVATIONS + DYNAMIC_OBSERVATIONS
INITIAL_OBSERVATIONS = STATIC_OBSERVATIONS + INITIAL_DYNAMIC_OBSERVATIONS
@gin.configurable
class ObservationConfig(object):
"""A class that contains shared configs for observation."""
# The default numbers are the maximum number of nodes, edges, and grid size
# on a set of TPU blocks.
# Large numbers may cause GPU/TPU OOM during training.
def __init__(self,
max_num_nodes: int = 5000,
max_num_edges: int = 28400,
max_grid_size: int = 128):
self.max_num_edges = max_num_edges
self.max_num_nodes = max_num_nodes
self.max_grid_size = max_grid_size
@property
def observation_space(self) -> gym.spaces.Space:
"""Env Observation space."""
return gym.spaces.Dict({
'normalized_num_edges':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'normalized_num_hard_macros':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'normalized_num_soft_macros':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'normalized_num_port_clusters':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'horizontal_routes_per_micron':
gym.spaces.Box(low=0, high=100, shape=(1,)),
'vertical_routes_per_micron':
gym.spaces.Box(low=0, high=100, shape=(1,)),
'macro_horizontal_routing_allocation':
gym.spaces.Box(low=0, high=100, shape=(1,)),
'macro_vertical_routing_allocation':
gym.spaces.Box(low=0, high=100, shape=(1,)),
'sparse_adj_weight':
gym.spaces.Box(low=0, high=100, shape=(self.max_num_edges,)),
'sparse_adj_i':
gym.spaces.Box(
low=0,
high=self.max_num_nodes - 1,
shape=(self.max_num_edges,),
dtype=np.int32),
'sparse_adj_j':
gym.spaces.Box(
low=0,
high=self.max_num_nodes - 1,
shape=(self.max_num_edges,),
dtype=np.int32),
'edge_counts':
gym.spaces.Box(
low=0,
high=self.max_num_edges - 1,
shape=(self.max_num_nodes,),
dtype=np.int32),
'node_types':
gym.spaces.Box(
low=0, high=3, shape=(self.max_num_nodes,), dtype=np.int32),
'is_node_placed':
gym.spaces.Box(
low=0, high=1, shape=(self.max_num_nodes,), dtype=np.int32),
'macros_w':
gym.spaces.Box(low=0, high=1, shape=(self.max_num_nodes,)),
'macros_h':
gym.spaces.Box(low=0, high=1, shape=(self.max_num_nodes,)),
'locations_x':
gym.spaces.Box(low=0, high=1, shape=(self.max_num_nodes,)),
'locations_y':
gym.spaces.Box(low=0, high=1, shape=(self.max_num_nodes,)),
'grid_cols':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'grid_rows':
gym.spaces.Box(low=0, high=1, shape=(1,)),
'current_node':
gym.spaces.Box(
low=0, high=self.max_num_nodes - 1, shape=(1,), dtype=np.int32),
'mask':
gym.spaces.Box(
low=0, high=1, shape=(self.max_grid_size**2,), dtype=np.int32),
})
def _to_dict(
flatten_obs: TensorType,
keys: FeatureKeyType,
observation_config: Optional[ObservationConfig] = None
) -> Dict[Text, TensorType]:
"""Unflatten the observation to a dictionary."""
if observation_config:
obs_space = observation_config.observation_space
else:
obs_space = ObservationConfig().observation_space
splits = [obs_space[k].shape[0] for k in keys]
splitted_obs = tf.split(flatten_obs, splits, axis=-1)
return {k: o for o, k in zip(splitted_obs, keys)}
def _flatten(dict_obs: Dict[Text, TensorType],
keys: FeatureKeyType) -> TensorType:
out = [np.asarray(dict_obs[k]) for k in keys]
return np.concatenate(out, axis=-1)
def flatten_static(dict_obs: Dict[Text, TensorType]) -> TensorType:
return _flatten(dict_obs=dict_obs, keys=STATIC_OBSERVATIONS)
def flatten_dynamic(dict_obs: Dict[Text, TensorType]) -> TensorType:
return _flatten(dict_obs=dict_obs, keys=DYNAMIC_OBSERVATIONS)
def flatten_all(dict_obs: Dict[Text, TensorType]) -> TensorType:
return _flatten(dict_obs=dict_obs, keys=ALL_OBSERVATIONS)
def flatten_initial(dict_obs: Dict[Text, TensorType]) -> TensorType:
return _flatten(dict_obs=dict_obs, keys=INITIAL_OBSERVATIONS)
def to_dict_static(
flatten_obs: TensorType,
observation_config: Optional[ObservationConfig] = None
) -> Dict[Text, TensorType]:
"""Convert the flattend numpy array of static observations back to a dict.
Args:
flatten_obs: a numpy array of static observations.
observation_config: Optional observation config.
Returns:
A dict representation of the observations.
"""
return _to_dict(
flatten_obs=flatten_obs,
keys=STATIC_OBSERVATIONS,
observation_config=observation_config)
def to_dict_dynamic(
flatten_obs: TensorType,
observation_config: Optional[ObservationConfig] = None
) -> Dict[Text, TensorType]:
"""Convert the flattend numpy array of dynamic observations back to a dict.
Args:
flatten_obs: a numpy array of dynamic observations.
observation_config: Optional observation config.
Returns:
A dict representation of the observations.
"""
return _to_dict(
flatten_obs=flatten_obs,
keys=DYNAMIC_OBSERVATIONS,
observation_config=observation_config)
def to_dict_all(
flatten_obs: TensorType,
observation_config: Optional[ObservationConfig] = None
) -> Dict[Text, TensorType]:
"""Convert the flattend numpy array of observations back to a dict.
Args:
flatten_obs: a numpy array of observations.
observation_config: Optional observation config.
Returns:
A dict representation of the observations.
"""
return _to_dict(
flatten_obs=flatten_obs,
keys=ALL_OBSERVATIONS,
observation_config=observation_config)
\ No newline at end of file
...@@ -1695,6 +1695,9 @@ class PlacementCost(object): ...@@ -1695,6 +1695,9 @@ class PlacementCost(object):
mod.set_orientation(orientation) mod.set_orientation(orientation)
def update_port_sides(self): def update_port_sides(self):
"""
Define Port "Side" by its location on canvas
"""
pass pass
def snap_ports_to_edges(self): def snap_ports_to_edges(self):
......
...@@ -7,6 +7,8 @@ from absl.flags import argparse_flags ...@@ -7,6 +7,8 @@ from absl.flags import argparse_flags
from absl import app from absl import app
from Plc_client import plc_client_os as plc_client_os from Plc_client import plc_client_os as plc_client_os
from Plc_client import placement_util_os as placement_util from Plc_client import placement_util_os as placement_util
from Plc_client import observation_extractor_os as observation_extractor
from Plc_client import observation_config
try: try:
from Plc_client import plc_client as plc_client from Plc_client import plc_client as plc_client
...@@ -460,7 +462,70 @@ class PlacementCostTest(): ...@@ -460,7 +462,70 @@ class PlacementCostTest():
raise AssertionError ("false") raise AssertionError ("false")
except AssertionError: except AssertionError:
print("[ERROR PLACEMENT UTIL] Saved PLC Discrepency found at line {}".format(str(idx))) print("[ERROR PLACEMENT UTIL] Saved PLC Discrepency found at line {}".format(str(idx)))
# if keep plc file for detailed comparison
if not keep_save_file:
os.remove('save_test_gl.plc')
os.remove('save_test_os.plc')
def test_observation_extractor(self):
"""
plc = placement_util.create_placement_cost(
netlist_file=netlist_file, init_placement='')
plc.set_canvas_size(300, 200)
plc.set_placement_grid(9, 4)
plc.unplace_all_nodes()
# Manually adds I/O port locations, this step is not needed for real
# netlists.
plc.update_node_coords('P0', 0.5, 100) # Left
plc.update_node_coords('P1', 150, 199.5) # Top
plc.update_port_sides()
plc.snap_ports_to_edges()
self.extractor = observation_extractor.ObservationExtractor(
plc=plc, observation_config=self._observation_config)
"""
try:
assert self.PLC_PATH
except AssertionError:
print("[ERROR OBSERVATION EXTRACTOR TEST] Facilitate required .plc file")
# Using the default edge/node
self._observation_config = observation_config.ObservationConfig(
max_num_edges=28400, max_num_nodes=5000, max_grid_size=128)
self.plc_util = placement_util.create_placement_cost(
plc_client=plc_client,
netlist_file=self.NETLIST_PATH,
init_placement=self.PLC_PATH
)
self.plc_util_os = placement_util.create_placement_cost(
plc_client=plc_client_os,
netlist_file=self.NETLIST_PATH,
init_placement=self.PLC_PATH
)
self.extractor = observation_extractor.ObservationExtractor(
plc=self.plc_util, observation_config=self._observation_config
)
self.extractor_os = observation_extractor.ObservationExtractor(
plc=self.plc_util_os, observation_config=self._observation_config
)
# Static features that are invariant across training steps
static_feature_gl = self.extractor._extract_static_features()
static_feature_os = self.extractor_os._extract_static_features()
for feature_gl, feature_os in zip(static_feature_gl, static_feature_os):
assert (static_feature_gl[feature_gl] == static_feature_os[feature_os]).all()
print(" ++++++++++++++++++++++++++++++++++++++++")
print(" +++ TEST OBSERVATION EXTRACTOR: PASS +++")
print(" ++++++++++++++++++++++++++++++++++++++++")
def test_place_node(self):
pass
def test_environment(self): def test_environment(self):
pass pass
...@@ -517,8 +582,9 @@ def main(args): ...@@ -517,8 +582,9 @@ def main(args):
# PCT.test_metadata() # PCT.test_metadata()
PCT.test_proxy_cost() PCT.test_proxy_cost()
PCT.test_placement_util() # PCT.test_placement_util()
# PCT.test_miscellaneous() # PCT.test_miscellaneous()
PCT.test_observation_extractor()
if __name__ == '__main__': if __name__ == '__main__':
app.run(main, flags_parser=parse_flags) app.run(main, flags_parser=parse_flags)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment