Commit 824b68b7 by lvzhengyang

add framework for training

parent 96539d53
...@@ -29,10 +29,15 @@ model = Policy(actor=actor_net, ...@@ -29,10 +29,15 @@ model = Policy(actor=actor_net,
nodes_xy, mask_visited = envs.reset() nodes_xy, mask_visited = envs.reset()
for episode in range(max_episode_num): for episode in range(max_episode_num):
v, h = model.act(nodes_xy, mask_visited=mask_visited) for i in range(num_nodes):
# v/h is a tensor of shape [#batch_size] on model.device if i == 0:
mask_visited, rewards, done, _ = envs.step(v, h) u0 = model.first_act(nodes_xy)
if done: mask_visited, rewards, done, _ = envs.step(u0)
model.learn(nodes_xy, rewards) continue
nodes_xy, mask_visited = envs.reset() v, h = model.act(nodes_xy, mask_visited=mask_visited)
# v/h is a tensor of shape [#batch_size] on model.device
mask_visited, rewards, done, _ = envs.step(v, h)
if done:
model.learn(nodes_xy, rewards)
nodes_xy, mask_visited = envs.reset()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment