lvzhengyang / REST · Commits

Commit 131c30fa
authored Sep 02, 2022 by lvzhengyang
parent 592f1b41

    fix errors on EPTM module in returning w and s

Showing 5 changed files with 68 additions and 56 deletions (+68 −56):
__pycache__/agent.cpython-38.pyc  (+0 −0)
__pycache__/model.cpython-38.pyc  (+0 −0)
agent.py  (+48 −29)
model.py  (+9 −13)
test.py   (+11 −14)
__pycache__/agent.cpython-38.pyc  (view file @ 131c30fa)
No preview for this file type.

__pycache__/model.cpython-38.pyc  (view file @ 131c30fa)
No preview for this file type.
agent.py  (view file @ 131c30fa)

@@ -38,8 +38,8 @@ class Actor(nn.Module):
     def forward(self, nodes, mask_visited=None, mask_unvisited=None):
         e = self.encoder(nodes)
-        u_probs, w_probs, s_probs = self.decoder(e, mask_visited, mask_unvisited)
-        return u_probs, w_probs, s_probs
+        u_probs, ws_probs = self.decoder(e, mask_visited, mask_unvisited)
+        return u_probs, ws_probs

 class Critic(nn.Module):
     def __init__(self, dim_e=D, dim_c=256) -> None:

@@ -57,6 +57,7 @@ class Critic(nn.Module):
     def forward(self, nodes):
         """
         @param nodes: [#batch, #num_nodes, 2] in dtype float
+        @return Expection for each batch, [#num_batch]
         """
         e = self.encoder(nodes)  # [#batch_size, #num_nodes, D]
         glimpse_w = self.g(torch.tanh(e)).squeeze_()  # [#batch_size, #num_nodes]

@@ -81,11 +82,17 @@ class Policy():
         self.actor = actor
         self.critic = critic
+        self.log_probs_buf = {"u0": None, "u": [], "ws": []}

     def first_act(self, nodes):
         """
         @brief perform action for t == 0, with the query = 0, get u0
         @param nodes: [#num_batch, #num_nodes, 2]
-        @return points: [#num_batch], dist of probs [#num_batch, #num_nodes]
+        @return points: [#num_batch], and the corresponding log_prob [#num_batch]
         """
         probs = self.actor.get_u0_probs(nodes)  # [#num_batch, #num_nodes]
         dist = Categorical(probs)

@@ -96,7 +103,8 @@ class Policy():
             _last_Eu.append(self.actor._node_e[i, u0[i]])
         self.actor.decoder._last_Eu = torch.stack(_last_Eu)
         self.actor.decoder._last_Ev = self.actor.decoder._last_Eu.clone()
-        return u0, dist
+        self.log_probs_buf['u0'] = dist.log_prob(u0)
+        return u0

     def act(self, nodes, mask_visited=None, mask_unvisited=None):
         """

@@ -104,50 +112,61 @@ class Policy():
         @param nodes: [#num_batch, #num_nodes, 2]
         @param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
             TODO: gather the input into one obs: contains a batch of obs
-        @return u/w/s: [#num_batch], and their corresponding log_prob [#num_batch]
         """
         if mask_visited == None and not mask_unvisited == None:
             mask_visited = ~mask_unvisited
         if mask_unvisited == None and not mask_visited == None:
             mask_unvisited = ~mask_visited
-        u_probs, _w_probs, s_probs = self.actor(nodes,
+        num_nodes = nodes.shape[1]
+        u_probs, ws_probs = self.actor(nodes,
             mask_visited=mask_visited, mask_unvisited=mask_unvisited)
         u_dist = Categorical(u_probs)
-        _w_dist = Categorical(_w_probs)  # wait to be choice by s
-        s_dist = Categorical(s_probs)
+        ws_dist = Categorical(ws_probs)
         u = u_dist.sample()
-        _w = _w_dist.sample()
-        s = s_dist.sample()
+        _w = ws_dist.sample()
+        self.log_probs_buf['u'].append(u_dist.log_prob(u))
+        self.log_probs_buf['ws'].append(ws_dist.log_prob(_w))
+        s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device)
+        w = torch.where(_w < num_nodes, _w, _w - num_nodes)
         batch_size = u.shape[0]
         _last_Eu = []
-        for i in range(batch_size):
-            _last_Eu.append(self.actor._node_e[i, u[i]])
-        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
         _last_Ew = []
         _last_Ev = []
         _last_Eh = []
-        w = []
-        w_probs = []
+        v = []
+        h = []
         for i in range(batch_size):
+            _last_Eu.append(self.actor._node_e[i, u[i]])
+            _last_Ew.append(self.actor._node_e[i, w[i]])
             if s[i] == 0:
                 _last_Ev.append(self.actor._node_e[i, u[i]])
-                _last_Eh.append(self.actor._node_e[i, _w[i][0]])
-                _last_Ew.append(self.actor._node_e[i, _w[i][0]])
-                w.append(_w[i, 0])
-                w_probs.append(_w_probs[i, 0])
+                _last_Eh.append(self.actor._node_e[i, w[i]])
+                v.append(u[i])
+                h.append(w[i])
             else:
-                _last_Ev.append(self.actor._node_e[i, _w[i][1]])
+                _last_Ev.append(self.actor._node_e[i, w[i]])
                 _last_Eh.append(self.actor._node_e[i, u[i]])
-                _last_Ew.append(self.actor._node_e[i, _w[i][1]])
-                w.append(_w[i, 1])
-                w_probs.append(_w_probs[i, 0])
+                v.append(w[i])
+                h.append(u[i])
+        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
+        self.actor.decoder._last_Ew = torch.stack(_last_Ew)
         self.actor.decoder._last_Ev = torch.stack(_last_Ev)
         self.actor.decoder._last_Eh = torch.stack(_last_Eh)
-        self.actor.decoder._last_Ew = torch.stack(_last_Ew)
-        # get the choiced w
-        w = torch.tensor(w, device=u.device)
-        w_probs = torch.stack(w_probs)
-        w_dist = Categorical(w_probs)
-        return u, w, s, u_dist, w_dist, s_dist
+        v = torch.stack(v)
+        h = torch.stack(h)
+        return v, h
+
+    def learn(self, rewards):
+        """
+        @brief update the parameters of actor and critic
+        @param rewards: reward of an episode in a batch. [#num_batch]
+        @note the env has returned 'done' signal
+        """
+        pass
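
For reference, the net effect of the act() changes above: a single Categorical over 2 * #num_nodes joint indices replaces the separate w and s distributions, and (s, w) are decoded from the sampled index. A minimal standalone sketch of that decode step; the batch size, node count, and random logits below are illustrative stand-ins, not values from the repo:

    import torch
    from torch.distributions import Categorical

    batch_size, num_nodes = 4, 10
    logits = torch.randn(batch_size, 2 * num_nodes)      # stand-in for the EPTM scores
    ws_probs = torch.softmax(logits, dim=-1)             # [#num_batch, 2 * #num_nodes]

    ws_dist = Categorical(ws_probs)
    _w = ws_dist.sample()                                # joint index in [0, 2 * num_nodes)
    log_prob = ws_dist.log_prob(_w)                      # one log-prob covers both w and s

    s = (_w >= num_nodes).long()                         # same result as the torch.where form in the diff
    w = torch.where(_w < num_nodes, _w, _w - num_nodes)  # node index within the chosen head

Because one distribution now produces the (w, s) pair, a single log-probability per step can be buffered in log_probs_buf['ws'] for the policy-gradient update stubbed out in learn().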
model.py  (view file @ 131c30fa)

@@ -177,8 +177,6 @@ class EPTM(nn.Module):
         super(EPTM, self).__init__()
         self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
         self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
-        # to get s
-        self.q_encode = nn.Linear(dim_q + dim_e, 2, bias=False)
         self.C = 10.0

     def forward(self, e, q, mask=None):

@@ -187,8 +185,10 @@ class EPTM(nn.Module):
         @param q: [#num_batch, 360]
         @param mask: [#num_batch, #num_nodes], with masked points set to True and the reset are False
             set unvisited points True
-        @return [#num_batch, 2, #num_nodes]
-        @note result[i, 0] is for batch i and s=0
+        @return [#num_batch, 2 * #num_nodes]
+        @note result[i, j]: in i-th sample the batch,
+            if j < #num_nodes: the prob of s = 0, and w = j
+            else: the prob of s = 0, and w = j - #num_nodes
         """
         _, l_0 = self.PTM_0(e, q, need_l=True)  # [#num_batch, #num_nodes]
         _, l_1 = self.PTM_1(e, q, need_l=True)  # [#num_batch, #num_nodes]

@@ -198,13 +198,9 @@ class EPTM(nn.Module):
             mask = torch.stack([mask, mask], dim=1)
         l = torch.where(mask == False, l, -torch.inf)
         l = torch.tanh(l).mul_(self.C)
-        p4w = torch.softmax(l, dim=-1)  # probability of w, [#num_batch, 2, #num_nodes]
-        e_mean = torch.mean(e, dim=1).squeeze()
-        s = torch.cat([e_mean, q], dim=1)
-        s = self.q_encode(s)
-        p4s = torch.softmax(s, dim=-1)
-        return p4w, p4s
+        batch_size = l.shape[0]
+        p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)
+        return p4ws

 class Decoder(nn.Module):
     def __init__(self, dim_e=D, dim_q=360) -> None:

@@ -253,6 +249,6 @@ class Decoder(nn.Module):
         u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
         cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
-        w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
-        return u, w, s
+        ws = self.eptm(e, cur_q4w, mask=mask_unvisited)
+        return u, ws
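
The model.py change that enables this: EPTM no longer softmaxes each PTM head separately and predicts s with a dedicated q_encode linear layer; it flattens the two heads' clipped logits and normalizes them jointly. A shape-only sketch with random tensors (masking omitted; stacking l_0 and l_1 along dim=1 is inferred from the mask handling in the same hunk):

    import torch

    batch_size, num_nodes, C = 4, 10, 10.0
    l_0 = torch.randn(batch_size, num_nodes)   # scores from PTM_0 (the s = 0 head)
    l_1 = torch.randn(batch_size, num_nodes)   # scores from PTM_1 (the s = 1 head)
    l = torch.stack([l_0, l_1], dim=1)         # [#num_batch, 2, #num_nodes]

    l = torch.tanh(l) * C                      # C-clipped logits, as in forward()
    p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)  # [#num_batch, 2 * #num_nodes]

    # one properly normalized distribution over (s, w) pairs:
    assert torch.allclose(p4ws.sum(dim=-1), torch.ones(batch_size))
    # index j < num_nodes encodes (s=0, w=j); otherwise (s=1, w=j - num_nodes)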
test.py  (view file @ 131c30fa)

-from turtle import up
 import torch
 import pdb

@@ -24,26 +23,23 @@ model = Policy(actor=actor_net,
                action_space=env.action_space
                )
-u0, _ = model.first_act(nodes)
-u_list = [u0]
-w_list = []
-s_list = []
+u0 = model.first_act(nodes)
+v_list = [u0]
+h_list = []
 mask_visited = torch.zeros(batch_size, num_nodes).bool()
 mask_visited = update_mask(mask_visited, batch_size, u0)
 for i in range(1, num_nodes):
-    u, w, s, u_dist, w_dist, s_dist = model.act(nodes, mask_visited)
-    mask_visited = update_mask(mask_visited, batch_size, u)
-    mask_visited = update_mask(mask_visited, batch_size, w)
-    u_list.append(u)
-    w_list.append(w)
-    s_list.append(s)
+    v, h = model.act(nodes, mask_visited)
+    mask_visited = update_mask(mask_visited, batch_size, v)
+    mask_visited = update_mask(mask_visited, batch_size, h)
+    v_list.append(v)
+    h_list.append(h)

 # transpose into [#num_batch, #num_nodes]
-all_u = torch.stack(u_list).transpose(1, 0)
-all_w = torch.stack(w_list).transpose(1, 0)
-all_s = torch.stack(s_list).transpose(1, 0)
+all_v = torch.stack(v_list).transpose(1, 0)
+all_h = torch.stack(h_list).transpose(1, 0)

 pdb.set_trace()
\ No newline at end of file
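
test.py only calls update_mask; its body is not part of this diff. A plausible minimal implementation consistent with the call sites above, offered purely as an assumption rather than the repo's code:

    import torch

    def update_mask(mask_visited, batch_size, chosen):
        # mask_visited: [batch_size, num_nodes] bool; chosen: [batch_size] long
        mask_visited = mask_visited.clone()
        mask_visited[torch.arange(batch_size), chosen] = True  # mark chosen nodes visited
        return mask_visited

Under that reading, updating the mask with both returned endpoints v and h keeps every selected node out of subsequent samples, which is why the loop now calls update_mask twice per step.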