Commit b5fadf19 by Dinple

eval ct ver1.0

parent f72beddb
......@@ -4,14 +4,19 @@ Flows/job
Flows/util/__pycache__
CodeElements/*/*/__pycache__
CodeElements/Plc_client/test/
CodeElements/Plc_client/test/*/*
CodeElements/Plc_client/test/*
CodeElements/Plc_client/__pycache__/*
CodeElements/Plc_client/proto_reader.py
CodeElements/Plc_client/plc_client.py
CodeElements/failed_proxy_plc/*
CodeElements/EvalCT/saved_policy/*
CodeElements/EvalCT/test/*
CodeElements/EvalCT/test/g657_ub5_nruns10_c5_r3_v3_rc1
CodeElements/EvalCT/snapshot*
CodeElements/EvalCT/circuit_training
CodeElements/EvalCT/__pycache__/
CodeElements/EvalCT/eval_run*.plc
CodeElements/EvalCT/eval_run*.plc
CodeElements/EvalCT/saved_policy/run_00/111/train/*
CodeElements/EvalCT/saved_policy/run_00/111/snapshot*.plc
CodeElements/EvalCT/saved_policy/run_00/111/rl*.plc
CodeElements/EvalCT/saved_policy/run_os_64128_g657_ub5_nruns10_c5_r3_v3_rc1
......@@ -31,8 +31,8 @@ Example
$ python3 -m eval_ct --netlist ./test/ariane/netlist.pb.txt\
--plc ./test/ariane/initial.plc\
--rundir run_os_64128_g657_ub5_nruns10_c5_r3_v3_rc1
--rundir run_00\
--ckptID policy_checkpoint_0000103984
"""
......@@ -80,7 +80,7 @@ class InfoMetric(py_metric.PyStepMetric):
def reset(self):
self._buffer.clear()
def evaulate(model_dir, create_env_fn):
def evaulate(model_dir, ckpt_id, create_env_fn):
# Create the path for the serialized greedy policy.
policy_saved_model_path = os.path.join(model_dir,
learner.POLICY_SAVED_MODEL_DIR,
......@@ -94,7 +94,7 @@ def evaulate(model_dir, create_env_fn):
policy_saved_chkpt_path = os.path.join(model_dir,
learner.POLICY_SAVED_MODEL_DIR,
"checkpoints/policy_checkpoint_0000107200")
"checkpoints", ckpt_id)
try:
assert os.path.isdir(policy_saved_chkpt_path)
print("#[POLICY SAVED CHECKPOINT PATH] " + policy_saved_chkpt_path)
......@@ -146,11 +146,13 @@ def evaulate(model_dir, create_env_fn):
def main(args):
NETLIST_FILE = args.netlist
INIT_PLACEMENT = args.plc
POLICY_CHECKPOINT_ID = args.ckptID
GLOBAL_SEED = 111
CD_RUNTIME = False
RUN_NAME = args.rundir
# extract eval testcase name
EVAL_TESTCASE = re.search("/test/(.+?)/netlist.pb.txt", NETLIST_FILE).group(1)
print(EVAL_TESTCASE)
create_env_fn = functools.partial(
environment.create_circuit_environment,
......@@ -164,7 +166,7 @@ def main(args):
)
evaulate(model_dir=os.path.join("./saved_policy", RUN_NAME, str(GLOBAL_SEED)),
create_env_fn=create_env_fn)
ckpt_id=POLICY_CHECKPOINT_ID, create_env_fn=create_env_fn)
def parse_flags(argv):
parser = argparse_flags.ArgumentParser(
......@@ -175,6 +177,9 @@ def parse_flags(argv):
help="Path to plc in .plc")
parser.add_argument("--rundir", required=True,
help="Path to run directory that contains saved policies")
parser.add_argument("--ckptID", required=True,
help="Policy checkpoint ID")
return parser.parse_args(argv[1:])
if __name__ == '__main__':
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment