Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
R
REST
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lvzhengyang
REST
Commits
7f1d8bef
Commit
7f1d8bef
authored
Sep 02, 2022
by
lvzhengyang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add "learn" part for Policy
parent
131c30fa
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
75 additions
and
17 deletions
+75
-17
__pycache__/agent.cpython-38.pyc
+0
-0
__pycache__/model.cpython-38.pyc
+0
-0
agent.py
+47
-7
model.py
+24
-10
test.py
+4
-0
No files found.
__pycache__/agent.cpython-38.pyc
View file @
7f1d8bef
No preview for this file type
__pycache__/model.cpython-38.pyc
View file @
7f1d8bef
No preview for this file type
agent.py
View file @
7f1d8bef
...
@@ -60,7 +60,7 @@ class Critic(nn.Module):
...
@@ -60,7 +60,7 @@ class Critic(nn.Module):
@return Expection for each batch, [#num_batch]
@return Expection for each batch, [#num_batch]
"""
"""
e
=
self
.
encoder
(
nodes
)
# [#batch_size, #num_nodes, D]
e
=
self
.
encoder
(
nodes
)
# [#batch_size, #num_nodes, D]
glimpse_w
=
self
.
g
(
torch
.
tanh
(
e
))
.
squeeze
_
()
# [#batch_size, #num_nodes]
glimpse_w
=
self
.
g
(
torch
.
tanh
(
e
))
.
squeeze
()
# [#batch_size, #num_nodes]
glimpse_w
=
torch
.
softmax
(
glimpse_w
,
dim
=
1
)
.
unsqueeze
(
1
)
# [#batch_size, 1, #num_nodes]
glimpse_w
=
torch
.
softmax
(
glimpse_w
,
dim
=
1
)
.
unsqueeze
(
1
)
# [#batch_size, 1, #num_nodes]
glimpse
=
torch
.
bmm
(
glimpse_w
,
e
)
.
squeeze
()
# [#batch_size, D]
glimpse
=
torch
.
bmm
(
glimpse_w
,
e
)
.
squeeze
()
# [#batch_size, D]
b
=
self
.
final
(
glimpse
)
.
squeeze
()
# [#batch_size]
b
=
self
.
final
(
glimpse
)
.
squeeze
()
# [#batch_size]
...
@@ -70,17 +70,26 @@ class Critic(nn.Module):
...
@@ -70,17 +70,26 @@ class Critic(nn.Module):
the action_space is of gym.spaces.Discrete
the action_space is of gym.spaces.Discrete
the agent outputs probabilities, use torch.distributions.categorical.Categorical()
the agent outputs probabilities, use torch.distributions.categorical.Categorical()
"""
"""
class
Policy
():
class
Policy
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
actor
,
actor
,
critic
,
critic
,
obs_space
,
obs_space
,
action_space
,
action_space
,
optimizer
=
None
,
lr
=
2.5e-4
,
weight_decay
=
5e-4
,
loss_fn_critic
=
torch
.
nn
.
MSELoss
()
)
->
None
:
)
->
None
:
self
.
obs_space
=
obs_space
super
(
Policy
,
self
)
.
__init__
()
self
.
action_space
=
action_space
self
.
actor
=
actor
self
.
actor
=
actor
self
.
critic
=
critic
self
.
critic
=
critic
self
.
obs_space
=
obs_space
self
.
action_space
=
action_space
if
optimizer
==
None
:
self
.
optimizer
=
torch
.
optim
.
Adam
(
self
.
parameters
(),
lr
=
lr
,
weight_decay
=
weight_decay
)
self
.
loss_fn_critic
=
loss_fn_critic
self
.
log_probs_buf
=
{
self
.
log_probs_buf
=
{
"u0"
:
None
,
"u0"
:
None
,
...
@@ -88,6 +97,9 @@ class Policy():
...
@@ -88,6 +97,9 @@ class Policy():
"ws"
:
[]
"ws"
:
[]
}
}
def
forward
(
self
,
x
):
raise
NotImplementedError
def
first_act
(
self
,
nodes
):
def
first_act
(
self
,
nodes
):
"""
"""
@brief perform action for t == 0, with the query = 0, get u0
@brief perform action for t == 0, with the query = 0, get u0
...
@@ -162,11 +174,39 @@ class Policy():
...
@@ -162,11 +174,39 @@ class Policy():
h
=
torch
.
stack
(
h
)
h
=
torch
.
stack
(
h
)
return
v
,
h
return
v
,
h
def
learn
(
self
,
rewards
):
def
learn
(
self
,
nodes
,
rewards
):
"""
"""
@brief update the parameters of actor and critic
@brief update the parameters of actor and critic via REINFORECE algo.
@param nodes: [#batch, #num_nodes, 2] in dtype float
@param rewards: reward of an episode in a batch. [#num_batch]
@param rewards: reward of an episode in a batch. [#num_batch]
@note the env has returned 'done' signal
@note the env has returned 'done' signal
"""
"""
pass
self
.
optimizer
.
zero_grad
()
# cal log(p_res) for each trajectory in the batch
u0
=
self
.
log_probs_buf
[
'u0'
]
# [#num_batch]
u
=
torch
.
stack
(
self
.
log_probs_buf
[
'u'
])
.
transpose
(
1
,
0
)
# [#num_batch, #num_nodes - 1]
ws
=
torch
.
stack
(
self
.
log_probs_buf
[
'ws'
])
.
transpose
(
1
,
0
)
# [#num_batch, #num_nodes - 1]
# p_res: [#num_batch]
p_res
=
torch
.
add
(
u
,
ws
)
.
sum
(
dim
=-
1
)
.
add
(
u0
)
with
torch
.
no_grad
():
baselines
=
self
.
critic
(
nodes
)
j
=
(
baselines
-
rewards
)
*
p_res
j
=
j
.
mean
()
baselines
=
self
.
critic
(
nodes
)
loss_critic
=
self
.
loss_fn_critic
(
baselines
,
rewards
)
j
.
backward
()
loss_critic
.
backward
()
self
.
optimizer
.
step
()
# Finally, reset the buf
self
.
log_probs_buf
=
{
"u0"
:
None
,
"u"
:
[],
"ws"
:
[]
}
model.py
View file @
7f1d8bef
...
@@ -45,11 +45,18 @@ class EncoderLayer(nn.Module):
...
@@ -45,11 +45,18 @@ class EncoderLayer(nn.Module):
k
=
self
.
key
(
x
)
k
=
self
.
key
(
x
)
v
=
self
.
value
(
x
)
v
=
self
.
value
(
x
)
x1
,
_
=
self
.
attention
(
query
=
q
,
key
=
k
,
value
=
v
,
need_weights
=
False
)
x1
,
_
=
self
.
attention
(
query
=
q
,
key
=
k
,
value
=
v
,
need_weights
=
False
)
x
.
add_
(
x1
)
x
=
x
.
add
(
x1
)
x
=
self
.
norm
(
x
.
view
(
-
1
,
x
.
shape
[
-
1
]))
.
reshape
(
batch_size
,
num_nodes
,
-
1
)
# x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
x
=
x
.
transpose
(
1
,
2
)
x
=
self
.
norm
(
x
)
x
=
x
.
transpose
(
1
,
2
)
x1
=
self
.
feed_foward
(
x
)
x1
=
self
.
feed_foward
(
x
)
x
.
add_
(
x1
)
x
=
x
.
add
(
x1
)
x
=
self
.
norm
(
x
.
view
(
-
1
,
x
.
shape
[
-
1
]))
.
reshape
(
batch_size
,
num_nodes
,
-
1
)
# x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
x
=
x
.
transpose
(
1
,
2
)
x
=
self
.
norm
(
x
)
x
=
x
.
transpose
(
1
,
2
)
return
x
return
x
class
Encoder
(
nn
.
Module
):
class
Encoder
(
nn
.
Module
):
...
@@ -58,7 +65,6 @@ class Encoder(nn.Module):
...
@@ -58,7 +65,6 @@ class Encoder(nn.Module):
self
.
dim
=
dim
self
.
dim
=
dim
self
.
N
=
N
self
.
N
=
N
self
.
W_emb
=
nn
.
Linear
(
2
,
self
.
dim
,
bias
=
False
)
# get E_0
self
.
W_emb
=
nn
.
Linear
(
2
,
self
.
dim
,
bias
=
False
)
# get E_0
# TODO: implement batch norm
self
.
encoder_layers
=
nn
.
Sequential
()
self
.
encoder_layers
=
nn
.
Sequential
()
for
i
in
range
(
self
.
N
):
for
i
in
range
(
self
.
N
):
self
.
encoder_layers
.
add_module
(
'{}_{}'
.
format
(
EncoderLayer
.
__name__
,
i
),
self
.
encoder_layers
.
add_module
(
'{}_{}'
.
format
(
EncoderLayer
.
__name__
,
i
),
...
@@ -72,7 +78,15 @@ class Encoder(nn.Module):
...
@@ -72,7 +78,15 @@ class Encoder(nn.Module):
batch_size
=
x
.
shape
[
0
]
batch_size
=
x
.
shape
[
0
]
num_nodes
=
x
.
shape
[
1
]
num_nodes
=
x
.
shape
[
1
]
x
=
self
.
W_emb
(
x
)
x
=
self
.
W_emb
(
x
)
x
=
self
.
norm
(
x
.
view
(
-
1
,
x
.
shape
[
-
1
]))
.
reshape
(
batch_size
,
num_nodes
,
-
1
)
x
=
x
.
transpose
(
1
,
2
)
x
=
self
.
norm
(
x
)
x
=
x
.
transpose
(
1
,
2
)
"""
dim = x.shape[-1]
x = x.view(-1, dim)
x = self.norm(x)
x = x.reshape(batch_size, num_nodes, -1)
"""
x
=
self
.
encoder_layers
(
x
)
x
=
self
.
encoder_layers
(
x
)
return
x
return
x
...
@@ -103,15 +117,15 @@ class PTM(nn.Module):
...
@@ -103,15 +117,15 @@ class PTM(nn.Module):
# e: [#num_batch, #num_nodes, 360]
# e: [#num_batch, #num_nodes, 360]
# q: [#num_batch, , 360]
# q: [#num_batch, , 360]
# for each batch, perform e + q (q is broadcasted in to #num_nodes)
# for each batch, perform e + q (q is broadcasted in to #num_nodes)
e
.
transpose_
(
0
,
1
)
.
add_
(
q
)
.
transpose_
(
0
,
1
)
e
=
e
.
transpose_
(
0
,
1
)
.
add
(
q
)
.
transpose
(
0
,
1
)
e
=
self
.
W_g
(
e
)
# get l
e
=
self
.
W_g
(
e
)
# get l
e
.
squeeze_
()
e
=
e
.
squeeze
()
l
=
e
.
clone
()
l
=
e
.
clone
()
if
mask
!=
None
:
if
mask
!=
None
:
# points be masked is set to be -INF
# points be masked is set to be -INF
e
=
torch
.
where
(
mask
==
False
,
e
,
-
torch
.
inf
)
e
=
torch
.
where
(
mask
==
False
,
e
,
-
torch
.
inf
)
e
=
torch
.
tanh
(
e
)
e
=
torch
.
tanh
(
e
)
e
.
mul_
(
self
.
C
)
e
=
e
.
mul
(
self
.
C
)
p
=
nn
.
functional
.
softmax
(
e
,
dim
=
1
)
p
=
nn
.
functional
.
softmax
(
e
,
dim
=
1
)
if
need_l
:
if
need_l
:
return
p
,
l
return
p
,
l
...
@@ -197,7 +211,7 @@ class EPTM(nn.Module):
...
@@ -197,7 +211,7 @@ class EPTM(nn.Module):
if
mask
.
dim
()
==
2
:
if
mask
.
dim
()
==
2
:
mask
=
torch
.
stack
([
mask
,
mask
],
dim
=
1
)
mask
=
torch
.
stack
([
mask
,
mask
],
dim
=
1
)
l
=
torch
.
where
(
mask
==
False
,
l
,
-
torch
.
inf
)
l
=
torch
.
where
(
mask
==
False
,
l
,
-
torch
.
inf
)
l
=
torch
.
tanh
(
l
)
.
mul
_
(
self
.
C
)
l
=
torch
.
tanh
(
l
)
.
mul
(
self
.
C
)
batch_size
=
l
.
shape
[
0
]
batch_size
=
l
.
shape
[
0
]
p4ws
=
torch
.
softmax
(
l
.
reshape
(
batch_size
,
-
1
),
dim
=-
1
)
p4ws
=
torch
.
softmax
(
l
.
reshape
(
batch_size
,
-
1
),
dim
=-
1
)
return
p4ws
return
p4ws
...
...
test.py
View file @
7f1d8bef
...
@@ -42,4 +42,7 @@ for i in range(1, num_nodes):
...
@@ -42,4 +42,7 @@ for i in range(1, num_nodes):
all_v
=
torch
.
stack
(
v_list
)
.
transpose
(
1
,
0
)
all_v
=
torch
.
stack
(
v_list
)
.
transpose
(
1
,
0
)
all_h
=
torch
.
stack
(
h_list
)
.
transpose
(
1
,
0
)
all_h
=
torch
.
stack
(
h_list
)
.
transpose
(
1
,
0
)
rewards
=
-
10
*
torch
.
ones
(
batch_size
)
model
.
learn
(
nodes
,
rewards
)
pdb
.
set_trace
()
pdb
.
set_trace
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment