Commit f26f7266 by lvzhengyang

build framework for training

parent 7f1d8bef
......@@ -74,8 +74,8 @@ class Policy(nn.Module):
def __init__(self,
actor,
critic,
obs_space,
action_space,
obs_space=None,
action_space=None,
optimizer=None,
lr=2.5e-4,
weight_decay=5e-4,
......
......@@ -25,9 +25,8 @@ class RSMTEnv(gym.Env):
})
self.action_space = spaces.Dict({
"u": spaces.Discrete(num_nodes),
"w": spaces.Discrete(num_nodes),
"s": spaces.Discrete(2),
"v": spaces.Discrete(num_nodes),
"h": spaces.Discrete(num_nodes),
})
self.mask_unvisited = None
......@@ -49,4 +48,8 @@ class RSMTEnv(gym.Env):
# return obs, reward, done, info
def close(self):
pass
\ No newline at end of file
pass
def create_RSMTEnv(*args, **kwarg):
    """Construct and return a new ``RSMTEnv``.

    All positional and keyword arguments are forwarded unchanged to the
    ``RSMTEnv`` constructor, so this can be used wherever a plain callable
    environment factory is wanted.
    """
    return RSMTEnv(*args, **kwarg)
\ No newline at end of file
......@@ -2,7 +2,7 @@ import torch
import pdb
from agent import Actor, Critic, Policy
from env.env import RSMTEnv
from env.env import create_RSMTEnv
def update_mask(mask, batch_size, node_visited):
for i in range(batch_size):
......@@ -13,7 +13,7 @@ batch_size = 4
num_nodes = 8
nodes = torch.randn(batch_size, num_nodes, 2)
env = RSMTEnv(num_nodes=num_nodes, pos_l=0, pos_h=100)
env = create_RSMTEnv(num_nodes=num_nodes, pos_l=0, pos_h=100)
actor_net = Actor()
critic_net = Critic()
......
import torch
from agent import Actor, Critic, Policy
from env.env import create_RSMTEnv
import gym
import pdb
# Training configuration.
batch_size = 4
num_nodes = 8
max_episode_num = 40000  # 40k
pos_l = 0.0
pos_h = 100.0


def main():
    """Train the policy on a batch of RSMT environments run in parallel.

    One environment is created per batch element and stepped through
    ``gym.vector.AsyncVectorEnv``. Whenever the batch reports it is done,
    the policy learns from the collected rewards and the environments are
    reset.
    """
    env_fns = [
        lambda: create_RSMTEnv(num_nodes=num_nodes, pos_l=pos_l, pos_h=pos_h)
        for _ in range(batch_size)
    ]
    envs = gym.vector.AsyncVectorEnv(env_fns)

    actor_net = Actor()
    critic_net = Critic()
    model = Policy(actor=actor_net, critic=critic_net)

    try:
        nodes_xy, mask_visited = envs.reset()
        for episode in range(max_episode_num):
            v, h = model.act(nodes_xy, mask_visited=mask_visited)
            # v/h is a tensor of shape [#batch_size] on model.device
            # The vector env takes ONE batched action argument; the env's
            # action space is a Dict keyed by "v" and "h". Tensors are moved
            # to host memory so they can be sent to the worker processes.
            actions = {
                "v": v.detach().cpu().numpy(),
                "h": h.detach().cpu().numpy(),
            }
            mask_visited, rewards, done, _ = envs.step(actions)
            # `done` holds one flag per environment, so it cannot be used
            # directly as a condition.
            # NOTE(review): assumes all environments finish on the same step
            # (same num_nodes) -- confirm, otherwise handle per-env resets.
            if all(done):
                model.learn(nodes_xy, rewards)
                nodes_xy, mask_visited = envs.reset()
    finally:
        # Always shut down the worker processes, even if training fails.
        envs.close()


# AsyncVectorEnv starts worker processes; without this guard the module-level
# training code would run again when a worker re-imports this file under the
# "spawn" start method.
if __name__ == "__main__":
    main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment