Commit 7beec4a5 by Dinple

environment WIP

parent 7789080d
...@@ -85,5 +85,11 @@ $$ ...@@ -85,5 +85,11 @@ $$
Notice a smoothing range can be set for congestion. This is only applied to congestion due to net routing which by counting adjacent cells and adding the averaged congestion to these adjacent cells. More details are provided in the document above. Notice a smoothing range can be set for congestion. This is only applied to congestion due to net routing which by counting adjacent cells and adding the averaged congestion to these adjacent cells. More details are provided in the document above.
## Placement Util
**Disclaimer: We DO NOT own the content of placement_util_os.py. All rights belong to Google Authors. This is a modified version of placement_util.py and we are including in the repo for the sake of testing. Original Code can be viewed [here](https://github.com/google-research/circuit_training/blob/main/circuit_training/environment/placement_util.py)**.
## Observation Extractor
**Disclaimer: We DO NOT own the content of observation_extractor_os.py. All rights belong to Google Authors. This is a modified version of observation_extractor.py and we are including in the repo for the sake of testing. Original Code can be viewed [here](https://github.com/google-research/circuit_training/blob/main/circuit_training/environment/observation_extractor.py)**.
...@@ -8,6 +8,7 @@ from absl import app ...@@ -8,6 +8,7 @@ from absl import app
from Plc_client import plc_client_os as plc_client_os from Plc_client import plc_client_os as plc_client_os
from Plc_client import placement_util_os as placement_util from Plc_client import placement_util_os as placement_util
from Plc_client import observation_extractor_os as observation_extractor from Plc_client import observation_extractor_os as observation_extractor
from Plc_client import environment_os as environment
from Plc_client import observation_config from Plc_client import observation_config
try: try:
...@@ -484,6 +485,7 @@ class PlacementCostTest(): ...@@ -484,6 +485,7 @@ class PlacementCostTest():
self.extractor = observation_extractor.ObservationExtractor( self.extractor = observation_extractor.ObservationExtractor(
plc=plc, observation_config=self._observation_config) plc=plc, observation_config=self._observation_config)
""" """
print("############################ TEST OBSERVATION EXTRACTOR ############################")
try: try:
assert self.PLC_PATH assert self.PLC_PATH
except AssertionError: except AssertionError:
...@@ -527,7 +529,32 @@ class PlacementCostTest(): ...@@ -527,7 +529,32 @@ class PlacementCostTest():
pass pass
def test_environment(self): def test_environment(self):
pass print("############################ TEST ENVIRONMENT ############################")
# Source: https://github.com/google-research/circuit_training/blob/d5e454e5bcd153a95d320f664af0d1b378aace7b/circuit_training/environment/environment_test.py#L39
def random_action(mask):
valid_actions, = np.nonzero(mask.flatten())
if len(valid_actions): # pylint: disable=g-explicit-length-test
return np.random.choice(valid_actions)
# If there is no valid choice, then `[0]` is returned which results in an
# infeasable action ending the episode.
return 0
env = environment.CircuitEnv(
_plc=plc_client,
create_placement_cost_fn=placement_util.create_placement_cost,
netlist_file=self.NETLIST_PATH,
init_placement=self.PLC_PATH)
env_os = environment.CircuitEnv(
_plc=plc_client_os,
create_placement_cost_fn=placement_util.create_placement_cost,
netlist_file=self.NETLIST_PATH,
init_placement=self.PLC_PATH)
print(" ++++++++++++++++++++++++++++++")
print(" +++ TEST ENVIRONMENT: PASS +++")
print(" ++++++++++++++++++++++++++++++")
def parse_flags(argv): def parse_flags(argv):
parser = argparse_flags.ArgumentParser(description='An argparse + app.run example') parser = argparse_flags.ArgumentParser(description='An argparse + app.run example')
...@@ -584,7 +611,8 @@ def main(args): ...@@ -584,7 +611,8 @@ def main(args):
PCT.test_proxy_cost() PCT.test_proxy_cost()
# PCT.test_placement_util() # PCT.test_placement_util()
# PCT.test_miscellaneous() # PCT.test_miscellaneous()
PCT.test_observation_extractor() # PCT.test_observation_extractor()
PCT.test_environment()
if __name__ == '__main__': if __name__ == '__main__':
app.run(main, flags_parser=parse_flags) app.run(main, flags_parser=parse_flags)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment