Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
R
REST
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lvzhengyang
REST
Commits
f26f7266
Commit
f26f7266
authored
Sep 02, 2022
by
lvzhengyang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
build framework for training
parent
7f1d8bef
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
48 additions
and
7 deletions
+48
-7
__pycache__/agent.cpython-38.pyc
+0
-0
__pycache__/model.cpython-38.pyc
+0
-0
agent.py
+2
-2
env/__pycache__/env.cpython-38.pyc
+0
-0
env/env.py
+7
-3
test.py
+2
-2
train.py
+37
-0
No files found.
__pycache__/agent.cpython-38.pyc
View file @
f26f7266
No preview for this file type
__pycache__/model.cpython-38.pyc
View file @
f26f7266
No preview for this file type
agent.py
View file @
f26f7266
...
...
@@ -74,8 +74,8 @@ class Policy(nn.Module):
def
__init__
(
self
,
actor
,
critic
,
obs_space
,
action_space
,
obs_space
=
None
,
action_space
=
None
,
optimizer
=
None
,
lr
=
2.5e-4
,
weight_decay
=
5e-4
,
...
...
env/__pycache__/env.cpython-38.pyc
View file @
f26f7266
No preview for this file type
env/env.py
View file @
f26f7266
...
...
@@ -25,9 +25,8 @@ class RSMTEnv(gym.Env):
})
self
.
action_space
=
spaces
.
Dict
({
"u"
:
spaces
.
Discrete
(
num_nodes
),
"w"
:
spaces
.
Discrete
(
num_nodes
),
"s"
:
spaces
.
Discrete
(
2
),
"v"
:
spaces
.
Discrete
(
num_nodes
),
"h"
:
spaces
.
Discrete
(
num_nodes
),
})
self
.
mask_unvisited
=
None
...
...
@@ -50,3 +49,7 @@ class RSMTEnv(gym.Env):
def
close
(
self
):
pass
def
create_RSMTEnv
(
*
args
,
**
kwarg
):
env
=
RSMTEnv
(
*
args
,
**
kwarg
)
return
env
\ No newline at end of file
test.py
View file @
f26f7266
...
...
@@ -2,7 +2,7 @@ import torch
import
pdb
from
agent
import
Actor
,
Critic
,
Policy
from
env.env
import
RSMTEnv
from
env.env
import
create_
RSMTEnv
def
update_mask
(
mask
,
batch_size
,
node_visited
):
for
i
in
range
(
batch_size
):
...
...
@@ -13,7 +13,7 @@ batch_size = 4
num_nodes
=
8
nodes
=
torch
.
randn
(
batch_size
,
num_nodes
,
2
)
env
=
RSMTEnv
(
num_nodes
=
num_nodes
,
pos_l
=
0
,
pos_h
=
100
)
env
=
create_
RSMTEnv
(
num_nodes
=
num_nodes
,
pos_l
=
0
,
pos_h
=
100
)
actor_net
=
Actor
()
critic_net
=
Critic
()
...
...
train.py
0 → 100644
View file @
f26f7266
import
torch
from
agent
import
Actor
,
Critic
,
Policy
from
env.env
import
create_RSMTEnv
import
gym
import
pdb
batch_size
=
4
num_nodes
=
8
max_episode_num
=
40000
# 40k
pos_l
=
0.0
pos_h
=
100.0
env_fns
=
[
lambda
:
create_RSMTEnv
(
num_nodes
=
num_nodes
,
pos_l
=
pos_l
,
pos_h
=
pos_h
)
for
i
in
range
(
batch_size
)
]
envs
=
gym
.
vector
.
AsyncVectorEnv
(
env_fns
)
actor_net
=
Actor
()
critic_net
=
Critic
()
model
=
Policy
(
actor
=
actor_net
,
critic
=
critic_net
,
)
nodes_xy
,
mask_visited
=
envs
.
reset
()
for
episode
in
range
(
max_episode_num
):
v
,
h
=
model
.
act
(
nodes_xy
,
mask_visited
=
mask_visited
)
# v/h is a tensor of shape [#batch_size] on model.device
mask_visited
,
rewards
,
done
,
_
=
envs
.
step
(
v
,
h
)
if
done
:
model
.
learn
(
nodes_xy
,
rewards
)
nodes_xy
,
mask_visited
=
envs
.
reset
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment