Commit b5fadf19 by Dinple

eval ct ver1.0

parent f72beddb
...@@ -4,14 +4,19 @@ Flows/job ...@@ -4,14 +4,19 @@ Flows/job
Flows/util/__pycache__ Flows/util/__pycache__
CodeElements/*/*/__pycache__ CodeElements/*/*/__pycache__
CodeElements/Plc_client/test/ CodeElements/Plc_client/test/
CodeElements/Plc_client/test/*/* CodeElements/Plc_client/test/*
CodeElements/Plc_client/__pycache__/* CodeElements/Plc_client/__pycache__/*
CodeElements/Plc_client/proto_reader.py CodeElements/Plc_client/proto_reader.py
CodeElements/Plc_client/plc_client.py CodeElements/Plc_client/plc_client.py
CodeElements/failed_proxy_plc/* CodeElements/failed_proxy_plc/*
CodeElements/EvalCT/saved_policy/* CodeElements/EvalCT/test/g657_ub5_nruns10_c5_r3_v3_rc1
CodeElements/EvalCT/test/*
CodeElements/EvalCT/snapshot* CodeElements/EvalCT/snapshot*
CodeElements/EvalCT/circuit_training CodeElements/EvalCT/circuit_training
CodeElements/EvalCT/__pycache__/ CodeElements/EvalCT/__pycache__/
CodeElements/EvalCT/eval_run*.plc CodeElements/EvalCT/eval_run*.plc
CodeElements/EvalCT/eval_run*.plc
CodeElements/EvalCT/saved_policy/run_00/111/train/*
CodeElements/EvalCT/saved_policy/run_00/111/snapshot*.plc
CodeElements/EvalCT/saved_policy/run_00/111/rl*.plc
CodeElements/EvalCT/saved_policy/run_os_64128_g657_ub5_nruns10_c5_r3_v3_rc1
...@@ -31,8 +31,8 @@ Example ...@@ -31,8 +31,8 @@ Example
$ python3 -m eval_ct --netlist ./test/ariane/netlist.pb.txt\ $ python3 -m eval_ct --netlist ./test/ariane/netlist.pb.txt\
--plc ./test/ariane/initial.plc\ --plc ./test/ariane/initial.plc\
--rundir run_os_64128_g657_ub5_nruns10_c5_r3_v3_rc1 --rundir run_00\
--ckptID policy_checkpoint_0000103984
""" """
...@@ -80,7 +80,7 @@ class InfoMetric(py_metric.PyStepMetric): ...@@ -80,7 +80,7 @@ class InfoMetric(py_metric.PyStepMetric):
def reset(self): def reset(self):
self._buffer.clear() self._buffer.clear()
def evaulate(model_dir, create_env_fn): def evaulate(model_dir, ckpt_id, create_env_fn):
# Create the path for the serialized greedy policy. # Create the path for the serialized greedy policy.
policy_saved_model_path = os.path.join(model_dir, policy_saved_model_path = os.path.join(model_dir,
learner.POLICY_SAVED_MODEL_DIR, learner.POLICY_SAVED_MODEL_DIR,
...@@ -94,7 +94,7 @@ def evaulate(model_dir, create_env_fn): ...@@ -94,7 +94,7 @@ def evaulate(model_dir, create_env_fn):
policy_saved_chkpt_path = os.path.join(model_dir, policy_saved_chkpt_path = os.path.join(model_dir,
learner.POLICY_SAVED_MODEL_DIR, learner.POLICY_SAVED_MODEL_DIR,
"checkpoints/policy_checkpoint_0000107200") "checkpoints", ckpt_id)
try: try:
assert os.path.isdir(policy_saved_chkpt_path) assert os.path.isdir(policy_saved_chkpt_path)
print("#[POLICY SAVED CHECKPOINT PATH] " + policy_saved_chkpt_path) print("#[POLICY SAVED CHECKPOINT PATH] " + policy_saved_chkpt_path)
...@@ -146,11 +146,13 @@ def evaulate(model_dir, create_env_fn): ...@@ -146,11 +146,13 @@ def evaulate(model_dir, create_env_fn):
def main(args): def main(args):
NETLIST_FILE = args.netlist NETLIST_FILE = args.netlist
INIT_PLACEMENT = args.plc INIT_PLACEMENT = args.plc
POLICY_CHECKPOINT_ID = args.ckptID
GLOBAL_SEED = 111 GLOBAL_SEED = 111
CD_RUNTIME = False CD_RUNTIME = False
RUN_NAME = args.rundir RUN_NAME = args.rundir
# extract eval testcase name
EVAL_TESTCASE = re.search("/test/(.+?)/netlist.pb.txt", NETLIST_FILE).group(1) EVAL_TESTCASE = re.search("/test/(.+?)/netlist.pb.txt", NETLIST_FILE).group(1)
print(EVAL_TESTCASE)
create_env_fn = functools.partial( create_env_fn = functools.partial(
environment.create_circuit_environment, environment.create_circuit_environment,
...@@ -164,7 +166,7 @@ def main(args): ...@@ -164,7 +166,7 @@ def main(args):
) )
evaulate(model_dir=os.path.join("./saved_policy", RUN_NAME, str(GLOBAL_SEED)), evaulate(model_dir=os.path.join("./saved_policy", RUN_NAME, str(GLOBAL_SEED)),
create_env_fn=create_env_fn) ckpt_id=POLICY_CHECKPOINT_ID, create_env_fn=create_env_fn)
def parse_flags(argv): def parse_flags(argv):
parser = argparse_flags.ArgumentParser( parser = argparse_flags.ArgumentParser(
...@@ -175,6 +177,9 @@ def parse_flags(argv): ...@@ -175,6 +177,9 @@ def parse_flags(argv):
help="Path to plc in .plc") help="Path to plc in .plc")
parser.add_argument("--rundir", required=True, parser.add_argument("--rundir", required=True,
help="Path to run directory that contains saved policies") help="Path to run directory that contains saved policies")
parser.add_argument("--ckptID", required=True,
help="Policy checkpoint ID")
return parser.parse_args(argv[1:]) return parser.parse_args(argv[1:])
if __name__ == '__main__': if __name__ == '__main__':
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment