Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
R
REST
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lvzhengyang
REST
Commits
131c30fa
Commit
131c30fa
authored
Sep 02, 2022
by
lvzhengyang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix errors on EPTM module in returning w and s
parent
592f1b41
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
56 deletions
+68
-56
__pycache__/agent.cpython-38.pyc
+0
-0
__pycache__/model.cpython-38.pyc
+0
-0
agent.py
+48
-29
model.py
+9
-13
test.py
+11
-14
No files found.
__pycache__/agent.cpython-38.pyc
View file @
131c30fa
No preview for this file type
__pycache__/model.cpython-38.pyc
View file @
131c30fa
No preview for this file type
agent.py
View file @
131c30fa
...
...
@@ -38,8 +38,8 @@ class Actor(nn.Module):
def
forward
(
self
,
nodes
,
mask_visited
=
None
,
mask_unvisited
=
None
):
e
=
self
.
encoder
(
nodes
)
u_probs
,
w
_probs
,
s_probs
=
self
.
decoder
(
e
,
mask_visited
,
mask_unvisited
)
return
u_probs
,
w
_probs
,
s_probs
u_probs
,
ws_probs
=
self
.
decoder
(
e
,
mask_visited
,
mask_unvisited
)
return
u_probs
,
ws_probs
class
Critic
(
nn
.
Module
):
def
__init__
(
self
,
dim_e
=
D
,
dim_c
=
256
)
->
None
:
...
...
@@ -57,6 +57,7 @@ class Critic(nn.Module):
def
forward
(
self
,
nodes
):
"""
@param nodes: [#batch, #num_nodes, 2] in dtype float
@return Expection for each batch, [#num_batch]
"""
e
=
self
.
encoder
(
nodes
)
# [#batch_size, #num_nodes, D]
glimpse_w
=
self
.
g
(
torch
.
tanh
(
e
))
.
squeeze_
()
# [#batch_size, #num_nodes]
...
...
@@ -81,11 +82,17 @@ class Policy():
self
.
actor
=
actor
self
.
critic
=
critic
self
.
log_probs_buf
=
{
"u0"
:
None
,
"u"
:
[],
"ws"
:
[]
}
def
first_act
(
self
,
nodes
):
"""
@brief perform action for t == 0, with the query = 0, get u0
@param nodes: [#num_batch, #num_nodes, 2]
@return points: [#num_batch],
dist of probs [#num_batch, #num_nodes
]
@return points: [#num_batch],
and the corresponding log_prob [#num_batch
]
"""
probs
=
self
.
actor
.
get_u0_probs
(
nodes
)
# [#num_batch, #num_nodes]
dist
=
Categorical
(
probs
)
...
...
@@ -96,7 +103,8 @@ class Policy():
_last_Eu
.
append
(
self
.
actor
.
_node_e
[
i
,
u0
[
i
]])
self
.
actor
.
decoder
.
_last_Eu
=
torch
.
stack
(
_last_Eu
)
self
.
actor
.
decoder
.
_last_Ev
=
self
.
actor
.
decoder
.
_last_Eu
.
clone
()
return
u0
,
dist
self
.
log_probs_buf
[
'u0'
]
=
dist
.
log_prob
(
u0
)
return
u0
def
act
(
self
,
nodes
,
mask_visited
=
None
,
mask_unvisited
=
None
):
"""
...
...
@@ -104,50 +112,61 @@ class Policy():
@param nodes: [#num_batch, #num_nodes, 2]
@param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
TODO: gather the input into one obs: contains a batch of obs
@return u/w/s: [#num_batch], and their corresponding log_prob [#num_batch]
"""
if
mask_visited
==
None
and
not
mask_unvisited
==
None
:
mask_visited
=
~
mask_unvisited
if
mask_unvisited
==
None
and
not
mask_visited
==
None
:
mask_unvisited
=
~
mask_visited
u_probs
,
_w_probs
,
s_probs
=
self
.
actor
(
nodes
,
num_nodes
=
nodes
.
shape
[
1
]
u_probs
,
ws_probs
=
self
.
actor
(
nodes
,
mask_visited
=
mask_visited
,
mask_unvisited
=
mask_unvisited
)
u_dist
=
Categorical
(
u_probs
)
_w_dist
=
Categorical
(
_w_probs
)
# wait to be choice by s
s_dist
=
Categorical
(
s_probs
)
ws_dist
=
Categorical
(
ws_probs
)
u
=
u_dist
.
sample
()
_w
=
_w_dist
.
sample
()
s
=
s_dist
.
sample
()
_w
=
ws_dist
.
sample
()
self
.
log_probs_buf
[
'u'
]
.
append
(
u_dist
.
log_prob
(
u
))
self
.
log_probs_buf
[
'ws'
]
.
append
(
ws_dist
.
log_prob
(
_w
))
s
=
torch
.
where
(
_w
<
num_nodes
,
int
(
0
),
int
(
1
))
.
to
(
_w
.
device
)
w
=
torch
.
where
(
_w
<
num_nodes
,
_w
,
_w
-
num_nodes
)
batch_size
=
u
.
shape
[
0
]
_last_Eu
=
[]
for
i
in
range
(
batch_size
):
_last_Eu
.
append
(
self
.
actor
.
_node_e
[
i
,
u
[
i
]])
self
.
actor
.
decoder
.
_last_Eu
=
torch
.
stack
(
_last_Eu
)
_last_Ew
=
[]
_last_Ev
=
[]
_last_Eh
=
[]
w
=
[]
w_probs
=
[]
v
=
[]
h
=
[]
for
i
in
range
(
batch_size
):
_last_Eu
.
append
(
self
.
actor
.
_node_e
[
i
,
u
[
i
]])
_last_Ew
.
append
(
self
.
actor
.
_node_e
[
i
,
w
[
i
]])
if
s
[
i
]
==
0
:
_last_Ev
.
append
(
self
.
actor
.
_node_e
[
i
,
u
[
i
]])
_last_Eh
.
append
(
self
.
actor
.
_node_e
[
i
,
_w
[
i
][
0
]])
_last_Ew
.
append
(
self
.
actor
.
_node_e
[
i
,
_w
[
i
][
0
]])
w
.
append
(
_w
[
i
,
0
])
w_probs
.
append
(
_w_probs
[
i
,
0
])
_last_Eh
.
append
(
self
.
actor
.
_node_e
[
i
,
w
[
i
]])
v
.
append
(
u
[
i
])
h
.
append
(
w
[
i
])
else
:
_last_Ev
.
append
(
self
.
actor
.
_node_e
[
i
,
_w
[
i
][
1
]])
_last_Ev
.
append
(
self
.
actor
.
_node_e
[
i
,
w
[
i
]])
_last_Eh
.
append
(
self
.
actor
.
_node_e
[
i
,
u
[
i
]])
_last_Ew
.
append
(
self
.
actor
.
_node_e
[
i
,
_w
[
i
][
1
]])
w
.
append
(
_w
[
i
,
1
])
w_probs
.
append
(
_w_probs
[
i
,
0
])
v
.
append
(
w
[
i
])
h
.
append
(
u
[
i
])
self
.
actor
.
decoder
.
_last_Eu
=
torch
.
stack
(
_last_Eu
)
self
.
actor
.
decoder
.
_last_Ew
=
torch
.
stack
(
_last_Ew
)
self
.
actor
.
decoder
.
_last_Ev
=
torch
.
stack
(
_last_Ev
)
self
.
actor
.
decoder
.
_last_Eh
=
torch
.
stack
(
_last_Eh
)
self
.
actor
.
decoder
.
_last_Ew
=
torch
.
stack
(
_last_Ew
)
# get the choiced w
w
=
torch
.
tensor
(
w
,
device
=
u
.
device
)
w_probs
=
torch
.
stack
(
w_probs
)
w_dist
=
Categorical
(
w_probs
)
return
u
,
w
,
s
,
u_dist
,
w_dist
,
s_dist
v
=
torch
.
stack
(
v
)
h
=
torch
.
stack
(
h
)
return
v
,
h
def
learn
(
self
,
rewards
):
"""
@brief update the parameters of actor and critic
@param rewards: reward of an episode in a batch. [#num_batch]
@note the env has returned 'done' signal
"""
pass
model.py
View file @
131c30fa
...
...
@@ -177,8 +177,6 @@ class EPTM(nn.Module):
super
(
EPTM
,
self
)
.
__init__
()
self
.
PTM_0
=
PTM
(
dim_e
=
dim_e
,
dim_q
=
dim_q
)
self
.
PTM_1
=
PTM
(
dim_e
=
dim_e
,
dim_q
=
dim_q
)
# to get s
self
.
q_encode
=
nn
.
Linear
(
dim_q
+
dim_e
,
2
,
bias
=
False
)
self
.
C
=
10.0
def
forward
(
self
,
e
,
q
,
mask
=
None
):
...
...
@@ -187,8 +185,10 @@ class EPTM(nn.Module):
@param q: [#num_batch, 360]
@param mask: [#num_batch, #num_nodes], with masked points set to True and the reset are False
set unvisited points True
@return [#num_batch, 2, #num_nodes]
@note result[i, 0] is for batch i and s=0
@return [#num_batch, 2 * #num_nodes]
@note result[i, j]: in i-th sample the batch,
if j < #num_nodes: the prob of s = 0, and w = j
else: the prob of s = 0, and w = j - #num_nodes
"""
_
,
l_0
=
self
.
PTM_0
(
e
,
q
,
need_l
=
True
)
# [#num_batch, #num_nodes]
_
,
l_1
=
self
.
PTM_1
(
e
,
q
,
need_l
=
True
)
# [#num_batch, #num_nodes]
...
...
@@ -198,13 +198,9 @@ class EPTM(nn.Module):
mask
=
torch
.
stack
([
mask
,
mask
],
dim
=
1
)
l
=
torch
.
where
(
mask
==
False
,
l
,
-
torch
.
inf
)
l
=
torch
.
tanh
(
l
)
.
mul_
(
self
.
C
)
p4w
=
torch
.
softmax
(
l
,
dim
=-
1
)
# probability of w, [#num_batch, 2, #num_nodes]
e_mean
=
torch
.
mean
(
e
,
dim
=
1
)
.
squeeze
()
s
=
torch
.
cat
([
e_mean
,
q
],
dim
=
1
)
s
=
self
.
q_encode
(
s
)
p4s
=
torch
.
softmax
(
s
,
dim
=-
1
)
return
p4w
,
p4s
batch_size
=
l
.
shape
[
0
]
p4ws
=
torch
.
softmax
(
l
.
reshape
(
batch_size
,
-
1
),
dim
=-
1
)
return
p4ws
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
dim_e
=
D
,
dim_q
=
360
)
->
None
:
...
...
@@ -253,6 +249,6 @@ class Decoder(nn.Module):
u
,
_
=
self
.
ptm
(
e
,
cur_q4u
,
mask
=
mask_visited
)
cur_q4w
=
self
.
q_gen
(
self
.
_last_edge
,
self
.
_last_subtree
,
u
)
w
,
s
=
self
.
eptm
(
e
,
cur_q4w
,
mask
=
mask_unvisited
)
ws
=
self
.
eptm
(
e
,
cur_q4w
,
mask
=
mask_unvisited
)
return
u
,
w
,
s
return
u
,
ws
test.py
View file @
131c30fa
from
turtle
import
up
import
torch
import
pdb
...
...
@@ -24,26 +23,23 @@ model = Policy(actor=actor_net,
action_space
=
env
.
action_space
)
u0
,
_
=
model
.
first_act
(
nodes
)
u_list
=
[
u0
]
w_list
=
[]
s_list
=
[]
u0
=
model
.
first_act
(
nodes
)
v_list
=
[
u0
]
h_list
=
[]
mask_visited
=
torch
.
zeros
(
batch_size
,
num_nodes
)
.
bool
()
mask_visited
=
update_mask
(
mask_visited
,
batch_size
,
u0
)
for
i
in
range
(
1
,
num_nodes
):
u
,
w
,
s
,
u_dist
,
w_dist
,
s_dist
=
model
.
act
(
nodes
,
mask_visited
)
mask_visited
=
update_mask
(
mask_visited
,
batch_size
,
u
)
mask_visited
=
update_mask
(
mask_visited
,
batch_size
,
w
)
u_list
.
append
(
u
)
w_list
.
append
(
w
)
s_list
.
append
(
s
)
v
,
h
=
model
.
act
(
nodes
,
mask_visited
)
mask_visited
=
update_mask
(
mask_visited
,
batch_size
,
v
)
mask_visited
=
update_mask
(
mask_visited
,
batch_size
,
h
)
v_list
.
append
(
v
)
h_list
.
append
(
h
)
# transpose into [#num_batch, #num_nodes]
all_u
=
torch
.
stack
(
u_list
)
.
transpose
(
1
,
0
)
all_w
=
torch
.
stack
(
w_list
)
.
transpose
(
1
,
0
)
all_s
=
torch
.
stack
(
s_list
)
.
transpose
(
1
,
0
)
all_v
=
torch
.
stack
(
v_list
)
.
transpose
(
1
,
0
)
all_h
=
torch
.
stack
(
h_list
)
.
transpose
(
1
,
0
)
pdb
.
set_trace
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment