Commit 824b68b7 by lvzhengyang

add framework for training

parent 96539d53
......@@ -29,10 +29,15 @@ model = Policy(actor=actor_net,
nodes_xy, mask_visited = envs.reset()
for episode in range(max_episode_num):
v, h = model.act(nodes_xy, mask_visited=mask_visited)
# v/h is a tensor of shape [#batch_size] on model.device
mask_visited, rewards, done, _ = envs.step(v, h)
if done:
model.learn(nodes_xy, rewards)
nodes_xy, mask_visited = envs.reset()
for i in range(num_nodes):
if i == 0:
u0 = model.first_act(nodes_xy)
mask_visited, rewards, done, _ = envs.step(u0)
continue
v, h = model.act(nodes_xy, mask_visited=mask_visited)
# v/h is a tensor of shape [#batch_size] on model.device
mask_visited, rewards, done, _ = envs.step(v, h)
if done:
model.learn(nodes_xy, rewards)
nodes_xy, mask_visited = envs.reset()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment