Commit 131c30fa by lvzhengyang

fix errors on EPTM module in returning w and s

parent 592f1b41
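The gist of the change: instead of producing separate distributions for the node w and the side flag s, the decoder's EPTM head now returns one joint distribution over 2 * #num_nodes entries, so a single categorical sample picks (w, s) together and yields one log-prob. A minimal sketch of this decode step, with illustrative shapes and names (not the repo's API):

    import torch
    from torch.distributions import Categorical

    num_batch, num_nodes = 4, 10
    ws_probs = torch.softmax(torch.randn(num_batch, 2 * num_nodes), dim=-1)
    ws_dist = Categorical(ws_probs)
    _w = ws_dist.sample()                                # flat index in [0, 2 * num_nodes)
    s = (_w >= num_nodes).long()                         # first half encodes s=0, second half s=1
    w = torch.where(_w < num_nodes, _w, _w - num_nodes)  # recover the node index
    log_prob = ws_dist.log_prob(_w)                      # one log-prob for the joint (w, s) choice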
@@ -38,8 +38,8 @@ class Actor(nn.Module):
     def forward(self, nodes, mask_visited=None, mask_unvisited=None):
         e = self.encoder(nodes)
-        u_probs, w_probs, s_probs = self.decoder(e, mask_visited, mask_unvisited)
-        return u_probs, w_probs, s_probs
+        u_probs, ws_probs = self.decoder(e, mask_visited, mask_unvisited)
+        return u_probs, ws_probs

 class Critic(nn.Module):
     def __init__(self, dim_e=D, dim_c=256) -> None:
@@ -57,6 +57,7 @@ class Critic(nn.Module):
     def forward(self, nodes):
         """
         @param nodes: [#batch, #num_nodes, 2] in dtype float
+        @return Expectation for each batch, [#num_batch]
         """
         e = self.encoder(nodes)  # [#batch_size, #num_nodes, D]
         glimpse_w = self.g(torch.tanh(e)).squeeze_()  # [#batch_size, #num_nodes]
@@ -81,11 +82,17 @@ class Policy():
         self.actor = actor
         self.critic = critic
+        self.log_probs_buf = {
+            "u0": None,
+            "u": [],
+            "ws": []
+        }

     def first_act(self, nodes):
         """
         @brief perform action for t == 0, with the query = 0, get u0
         @param nodes: [#num_batch, #num_nodes, 2]
-        @return points: [#num_batch], dist of probs [#num_batch, #num_nodes]
+        @return points: [#num_batch], and the corresponding log_prob [#num_batch]
         """
         probs = self.actor.get_u0_probs(nodes)  # [#num_batch, #num_nodes]
         dist = Categorical(probs)
@@ -96,7 +103,8 @@ class Policy():
             _last_Eu.append(self.actor._node_e[i, u0[i]])
         self.actor.decoder._last_Eu = torch.stack(_last_Eu)
         self.actor.decoder._last_Ev = self.actor.decoder._last_Eu.clone()
-        return u0, dist
+        self.log_probs_buf['u0'] = dist.log_prob(u0)
+        return u0

     def act(self, nodes, mask_visited=None, mask_unvisited=None):
         """
@@ -104,50 +112,61 @@ class Policy():
         @param nodes: [#num_batch, #num_nodes, 2]
         @param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
         TODO: gather the input into one obs: contains a batch of obs
+        @return v/h: [#num_batch]; the corresponding log_probs [#num_batch]
+            are stored in self.log_probs_buf
         """
         if mask_visited == None and not mask_unvisited == None:
             mask_visited = ~mask_unvisited
         if mask_unvisited == None and not mask_visited == None:
             mask_unvisited = ~mask_visited
-        u_probs, _w_probs, s_probs = self.actor(nodes,
+        num_nodes = nodes.shape[1]
+        u_probs, ws_probs = self.actor(nodes,
             mask_visited=mask_visited, mask_unvisited=mask_unvisited)
         u_dist = Categorical(u_probs)
-        _w_dist = Categorical(_w_probs)  # wait to be choice by s
-        s_dist = Categorical(s_probs)
+        ws_dist = Categorical(ws_probs)
         u = u_dist.sample()
-        _w = _w_dist.sample()
-        s = s_dist.sample()
+        _w = ws_dist.sample()
+        self.log_probs_buf['u'].append(u_dist.log_prob(u))
+        self.log_probs_buf['ws'].append(ws_dist.log_prob(_w))
+        s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device)
+        w = torch.where(_w < num_nodes, _w, _w - num_nodes)
         batch_size = u.shape[0]
         _last_Eu = []
-        for i in range(batch_size):
-            _last_Eu.append(self.actor._node_e[i, u[i]])
-        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
         _last_Ew = []
         _last_Ev = []
         _last_Eh = []
-        w = []
-        w_probs = []
+        v = []
+        h = []
         for i in range(batch_size):
+            _last_Eu.append(self.actor._node_e[i, u[i]])
+            _last_Ew.append(self.actor._node_e[i, w[i]])
             if s[i] == 0:
                 _last_Ev.append(self.actor._node_e[i, u[i]])
-                _last_Eh.append(self.actor._node_e[i, _w[i][0]])
-                _last_Ew.append(self.actor._node_e[i, _w[i][0]])
-                w.append(_w[i, 0])
-                w_probs.append(_w_probs[i, 0])
+                _last_Eh.append(self.actor._node_e[i, w[i]])
+                v.append(u[i])
+                h.append(w[i])
             else:
-                _last_Ev.append(self.actor._node_e[i, _w[i][1]])
+                _last_Ev.append(self.actor._node_e[i, w[i]])
                 _last_Eh.append(self.actor._node_e[i, u[i]])
-                _last_Ew.append(self.actor._node_e[i, _w[i][1]])
-                w.append(_w[i, 1])
-                w_probs.append(_w_probs[i, 0])
+                v.append(w[i])
+                h.append(u[i])
+        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
+        self.actor.decoder._last_Ew = torch.stack(_last_Ew)
         self.actor.decoder._last_Ev = torch.stack(_last_Ev)
         self.actor.decoder._last_Eh = torch.stack(_last_Eh)
-        self.actor.decoder._last_Ew = torch.stack(_last_Ew)
-        # get the choiced w
-        w = torch.tensor(w, device=u.device)
-        w_probs = torch.stack(w_probs)
-        w_dist = Categorical(w_probs)
-        return u, w, s, u_dist, w_dist, s_dist
+        v = torch.stack(v)
+        h = torch.stack(h)
+        return v, h
+
+    def learn(self, rewards):
+        """
+        @brief update the parameters of actor and critic
+        @param rewards: reward of an episode in a batch. [#num_batch]
+        @note the env has returned 'done' signal
+        """
+        pass
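learn() is added above as a stub. Given the log-probs buffered by first_act()/act() and the critic's value estimate, one plausible way to fill it in is a REINFORCE update with the critic as baseline; the sketch below is an assumption, not part of the commit (the extra nodes/optimizer arguments and the MSE critic loss are illustrative):

    import torch
    import torch.nn.functional as F

    def learn(self, rewards, nodes, actor_opt, critic_opt):
        # hypothetical signature: the stub in the commit only takes rewards
        baseline = self.critic(nodes)                 # [#num_batch] value estimate
        advantage = rewards - baseline.detach()       # [#num_batch]
        # sum the buffered per-step log-probs over the episode
        log_probs = (self.log_probs_buf['u0']
                     + torch.stack(self.log_probs_buf['u']).sum(dim=0)
                     + torch.stack(self.log_probs_buf['ws']).sum(dim=0))
        actor_loss = -(advantage * log_probs).mean()  # REINFORCE with baseline
        critic_loss = F.mse_loss(baseline, rewards)
        actor_opt.zero_grad(); actor_loss.backward(); actor_opt.step()
        critic_opt.zero_grad(); critic_loss.backward(); critic_opt.step()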
@@ -177,8 +177,6 @@ class EPTM(nn.Module):
         super(EPTM, self).__init__()
         self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
         self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
-        # to get s
-        self.q_encode = nn.Linear(dim_q + dim_e, 2, bias=False)
         self.C = 10.0

     def forward(self, e, q, mask=None):
@@ -187,8 +185,10 @@ class EPTM(nn.Module):
         @param q: [#num_batch, 360]
         @param mask: [#num_batch, #num_nodes], with masked points set to True and the rest False
             set unvisited points True
-        @return [#num_batch, 2, #num_nodes]
-        @note result[i, 0] is for batch i and s=0
+        @return [#num_batch, 2 * #num_nodes]
+        @note result[i, j]: for the i-th sample in the batch,
+            if j < #num_nodes: the prob of s = 0 and w = j
+            else: the prob of s = 1 and w = j - #num_nodes
         """
         _, l_0 = self.PTM_0(e, q, need_l=True)  # [#num_batch, #num_nodes]
         _, l_1 = self.PTM_1(e, q, need_l=True)  # [#num_batch, #num_nodes]
@@ -198,13 +198,9 @@ class EPTM(nn.Module):
             mask = torch.stack([mask, mask], dim=1)
         l = torch.where(mask == False, l, -torch.inf)
         l = torch.tanh(l).mul_(self.C)
-        p4w = torch.softmax(l, dim=-1)  # probability of w, [#num_batch, 2, #num_nodes]
-
-        e_mean = torch.mean(e, dim=1).squeeze()
-        s = torch.cat([e_mean, q], dim=1)
-        s = self.q_encode(s)
-        p4s = torch.softmax(s, dim=-1)
-        return p4w, p4s
+        batch_size = l.shape[0]
+        p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)
+        return p4ws

 class Decoder(nn.Module):
     def __init__(self, dim_e=D, dim_q=360) -> None:
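Two things worth noting about the new forward(): the softmax over the flattened logits yields a single distribution covering both s-halves, and because the mask is applied before torch.tanh, a masked logit ends up at tanh(-inf) * C = -C rather than -inf, leaving masked nodes a tiny but nonzero probability (masking after the tanh scaling would zero it exactly). A self-contained check with illustrative values:

    import torch

    C = 10.0
    l = torch.randn(1, 2, 5)                        # [#num_batch, 2, #num_nodes]
    mask = torch.tensor([[False, True, False, False, False]])
    mask = torch.stack([mask, mask], dim=1)         # same mask for both s-halves
    l = torch.where(mask == False, l, -torch.inf)
    l = torch.tanh(l).mul_(C)                       # tanh(-inf) == -1, so masked logits become -C
    p4ws = torch.softmax(l.reshape(1, -1), dim=-1)  # [#num_batch, 2 * #num_nodes]
    print(p4ws.sum())                               # 1.0: one joint distribution over (s, w)
    print(p4ws[0, 1], p4ws[0, 6])                   # masked node 1: tiny but not exactly zero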
@@ -253,6 +249,6 @@ class Decoder(nn.Module):
         u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
         cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
-        w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
-        return u, w, s
+        ws = self.eptm(e, cur_q4w, mask=mask_unvisited)
+        return u, ws
-from turtle import up
 import torch
 import pdb
@@ -24,26 +23,23 @@ model = Policy(actor=actor_net,
     action_space=env.action_space
 )
-u0, _ = model.first_act(nodes)
-u_list = [u0]
-w_list = []
-s_list = []
+u0 = model.first_act(nodes)
+v_list = [u0]
+h_list = []
 mask_visited = torch.zeros(batch_size, num_nodes).bool()
 mask_visited = update_mask(mask_visited, batch_size, u0)
 for i in range(1, num_nodes):
-    u, w, s, u_dist, w_dist, s_dist = model.act(nodes, mask_visited)
-    mask_visited = update_mask(mask_visited, batch_size, u)
-    mask_visited = update_mask(mask_visited, batch_size, w)
-    u_list.append(u)
-    w_list.append(w)
-    s_list.append(s)
+    v, h = model.act(nodes, mask_visited)
+    mask_visited = update_mask(mask_visited, batch_size, v)
+    mask_visited = update_mask(mask_visited, batch_size, h)
+    v_list.append(v)
+    h_list.append(h)

 # transpose into [#num_batch, #num_nodes]
-all_u = torch.stack(u_list).transpose(1, 0)
-all_w = torch.stack(w_list).transpose(1, 0)
-all_s = torch.stack(s_list).transpose(1, 0)
+all_v = torch.stack(v_list).transpose(1, 0)
+all_h = torch.stack(h_list).transpose(1, 0)
 pdb.set_trace()
\ No newline at end of file
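update_mask() is called throughout the script but its definition is not part of this diff; from its call sites, a minimal sketch of the assumed behavior (hypothetical, not the repo's actual helper):

    import torch

    def update_mask(mask_visited, batch_size, chosen):
        # mask_visited: [#num_batch, #num_nodes] bool, chosen: [#num_batch] long
        # assumed behavior: mark each batch element's chosen node as visited
        mask_visited[torch.arange(batch_size), chosen] = True
        return mask_visited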