Commit 131c30fa by lvzhengyang

fix errors in the EPTM module when returning w and s

parent 592f1b41
......@@ -38,8 +38,8 @@ class Actor(nn.Module):
def forward(self, nodes, mask_visited=None, mask_unvisited=None):
e = self.encoder(nodes)
u_probs, w_probs, s_probs = self.decoder(e, mask_visited, mask_unvisited)
return u_probs, w_probs, s_probs
u_probs, ws_probs = self.decoder(e, mask_visited, mask_unvisited)
return u_probs, ws_probs
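
As a quick illustration of the new forward contract: the decoder now yields one distribution over u and one joint distribution over (w, s). A minimal sketch with dummy probabilities standing in for the encoder/decoder output (B, N and the random tensors below are assumptions for illustration only, not part of this commit):

```python
import torch
from torch.distributions import Categorical

# Illustrative shapes only: B = batch size, N = number of nodes (assumed values).
B, N = 4, 10
u_probs = torch.softmax(torch.randn(B, N), dim=-1)       # [B, N]     distribution over u
ws_probs = torch.softmax(torch.randn(B, 2 * N), dim=-1)  # [B, 2 * N] joint distribution over (w, s)

u = Categorical(u_probs).sample()    # [B]
ws = Categorical(ws_probs).sample()  # [B] flat index into the 2 * N joint support
assert u.shape == (B,) and ws.shape == (B,)
```
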
class Critic(nn.Module):
def __init__(self, dim_e=D, dim_c=256) -> None:
......@@ -57,6 +57,7 @@ class Critic(nn.Module):
def forward(self, nodes):
"""
@param nodes: [#batch, #num_nodes, 2] in dtype float
@return Expectation for each batch, [#num_batch]
"""
e = self.encoder(nodes) # [#batch_size, #num_nodes, D]
glimpse_w = self.g(torch.tanh(e)).squeeze_() # [#batch_size, #num_nodes]
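
The new @return line says the critic emits one scalar expectation per batch element. A small sketch of how such a per-batch value is commonly used as a baseline (the names rewards/values/advantage are illustrative, not taken from this repository):

```python
import torch

B = 4
rewards = torch.randn(B)      # e.g. negative tour length per batch element (assumed)
values = torch.randn(B)       # stand-in for critic(nodes), shape [#num_batch]
advantage = rewards - values  # [#num_batch], drives the actor update
critic_loss = advantage.pow(2).mean()
```
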
......@@ -81,11 +82,17 @@ class Policy():
self.actor = actor
self.critic = critic
self.log_probs_buf = {
"u0": None,
"u": [],
"ws": []
}
def first_act(self, nodes):
"""
@brief perform the action for t == 0, with the query = 0, to get u0
@param nodes: [#num_batch, #num_nodes, 2]
@return points: [#num_batch], dist of probs [#num_batch, #num_nodes]
@return points: [#num_batch]; the corresponding log_prob [#num_batch] is stored in log_probs_buf
"""
probs = self.actor.get_u0_probs(nodes) # [#num_batch, #num_nodes]
dist = Categorical(probs)
......@@ -96,7 +103,8 @@ class Policy():
_last_Eu.append(self.actor._node_e[i, u0[i]])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
self.actor.decoder._last_Ev = self.actor.decoder._last_Eu.clone()
return u0, dist
self.log_probs_buf['u0'] = dist.log_prob(u0)
return u0
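
The per-sample loop that collects the embedding of the chosen node (here and in act below) can equivalently be written as a single indexed gather; a small sketch assuming _node_e is [B, N, D] and u0 is [B] (this is an equivalence check, not a change made by the commit):

```python
import torch

# Assumed shapes: node_e is [B, N, D], u0 is [B] (indices of the chosen nodes).
B, N, D = 4, 10, 128
node_e = torch.randn(B, N, D)
u0 = torch.randint(0, N, (B,))

# Loop form, as in first_act / act:
last_Eu_loop = torch.stack([node_e[i, u0[i]] for i in range(B)])  # [B, D]

# Equivalent vectorized gather:
last_Eu_gather = node_e[torch.arange(B), u0]                      # [B, D]

assert torch.equal(last_Eu_loop, last_Eu_gather)
```
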
def act(self, nodes, mask_visited=None, mask_unvisited=None):
"""
......@@ -104,50 +112,61 @@ class Policy():
@param nodes: [#num_batch, #num_nodes, 2]
@param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
TODO: gather the inputs into one obs that contains a batch of observations
@return v/h: [#num_batch]; the log_probs of the sampled u and (w, s) are stored in log_probs_buf
"""
if mask_visited == None and not mask_unvisited == None:
mask_visited = ~mask_unvisited
if mask_unvisited == None and not mask_visited == None:
mask_unvisited = ~mask_visited
u_probs, _w_probs, s_probs = self.actor(nodes,
num_nodes = nodes.shape[1]
u_probs, ws_probs = self.actor(nodes,
mask_visited=mask_visited, mask_unvisited=mask_unvisited)
u_dist = Categorical(u_probs)
_w_dist = Categorical(_w_probs) # the choice between the two is deferred to s
s_dist = Categorical(s_probs)
ws_dist = Categorical(ws_probs)
u = u_dist.sample()
_w = _w_dist.sample()
s = s_dist.sample()
_w = ws_dist.sample()
self.log_probs_buf['u'].append(u_dist.log_prob(u))
self.log_probs_buf['ws'].append(ws_dist.log_prob(_w))
s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device)
w = torch.where(_w < num_nodes, _w, _w - num_nodes)
batch_size = u.shape[0]
_last_Eu = []
for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u[i]])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
_last_Ew = []
_last_Ev = []
_last_Eh = []
w = []
w_probs = []
v = []
h = []
for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u[i]])
_last_Ew.append(self.actor._node_e[i, w[i]])
if s[i] == 0:
_last_Ev.append(self.actor._node_e[i, u[i]])
_last_Eh.append(self.actor._node_e[i, _w[i][0]])
_last_Ew.append(self.actor._node_e[i, _w[i][0]])
w.append(_w[i, 0])
w_probs.append(_w_probs[i, 0])
_last_Eh.append(self.actor._node_e[i, w[i]])
v.append(u[i])
h.append(w[i])
else:
_last_Ev.append(self.actor._node_e[i, _w[i][1]])
_last_Ev.append(self.actor._node_e[i, w[i]])
_last_Eh.append(self.actor._node_e[i, u[i]])
_last_Ew.append(self.actor._node_e[i, _w[i][1]])
w.append(_w[i, 1])
w_probs.append(_w_probs[i, 0])
v.append(w[i])
h.append(u[i])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
self.actor.decoder._last_Ew = torch.stack(_last_Ew)
self.actor.decoder._last_Ev = torch.stack(_last_Ev)
self.actor.decoder._last_Eh = torch.stack(_last_Eh)
self.actor.decoder._last_Ew = torch.stack(_last_Ew)
# get the chosen w
w = torch.tensor(w, device=u.device)
w_probs = torch.stack(w_probs)
w_dist = Categorical(w_probs)
return u, w, s, u_dist, w_dist, s_dist
v = torch.stack(v)
h = torch.stack(h)
return v, h
def learn(self, rewards):
"""
@brief update the parameters of actor and critic
@param rewards: reward of an episode in a batch. [#num_batch]
@note the env has returned the 'done' signal
"""
pass
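
learn() is still a stub in this commit. If the buffered log-probs are meant for a REINFORCE-style update with the critic as baseline, the loss could plausibly be assembled as below; this is only a hedged sketch (the function name, the summation over steps, and the loss terms are assumptions, not the author's implementation):

```python
import torch

def reinforce_losses(log_probs_buf, rewards, values):
    """Sketch only. Assumes the buffer built in first_act/act:
    'u0' is [#num_batch], 'u' and 'ws' are non-empty lists of [#num_batch] tensors;
    rewards and values are [#num_batch]."""
    # Total log-prob of every decision taken in the episode.
    log_prob_sum = (log_probs_buf['u0']
                    + torch.stack(log_probs_buf['u']).sum(dim=0)
                    + torch.stack(log_probs_buf['ws']).sum(dim=0))  # [#num_batch]
    advantage = rewards - values                                    # [#num_batch]
    actor_loss = -(advantage.detach() * log_prob_sum).mean()
    critic_loss = advantage.pow(2).mean()
    return actor_loss, critic_loss
```
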
......@@ -177,8 +177,6 @@ class EPTM(nn.Module):
super(EPTM, self).__init__()
self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
# to get s
self.q_encode = nn.Linear(dim_q + dim_e, 2, bias=False)
self.C = 10.0
def forward(self, e, q, mask=None):
......@@ -187,8 +185,10 @@ class EPTM(nn.Module):
@param q: [#num_batch, 360]
@param mask: [#num_batch, #num_nodes], with masked points set to True and the rest set to False
set unvisited points to True
@return [#num_batch, 2, #num_nodes]
@note result[i, 0] is for batch i and s=0
@return [#num_batch, 2 * #num_nodes]
@note result[i, j]: for the i-th sample in the batch,
if j < #num_nodes: the prob of s = 0 and w = j
else: the prob of s = 1 and w = j - #num_nodes
"""
_, l_0 = self.PTM_0(e, q, need_l=True) # [#num_batch, #num_nodes]
_, l_1 = self.PTM_1(e, q, need_l=True) # [#num_batch, #num_nodes]
......@@ -198,13 +198,9 @@ class EPTM(nn.Module):
mask = torch.stack([mask, mask], dim=1)
l = torch.where(mask == False, l, -torch.inf)
l = torch.tanh(l).mul_(self.C)
p4w = torch.softmax(l, dim=-1) # probability of w, [#num_batch, 2, #num_nodes]
e_mean = torch.mean(e, dim=1).squeeze()
s = torch.cat([e_mean, q], dim=1)
s = self.q_encode(s)
p4s = torch.softmax(s ,dim=-1)
return p4w, p4s
batch_size = l.shape[0]
p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)
return p4ws
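
Column j of p4ws encodes the (s, w) pair described in the docstring. A small self-check that the torch.where decode used in Policy.act agrees with the equivalent floor-division/modulo view of the same layout (num_nodes and the sample indices below are made up for illustration):

```python
import torch

num_nodes = 10
_w = torch.tensor([3, 10, 19, 0])  # sampled flat indices in [0, 2 * num_nodes)

# Decode as done in Policy.act:
s = torch.where(_w < num_nodes, 0, 1)
w = torch.where(_w < num_nodes, _w, _w - num_nodes)

# Equivalent view of the same flattened layout:
assert torch.equal(s, _w // num_nodes)
assert torch.equal(w, _w % num_nodes)
```
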
class Decoder(nn.Module):
def __init__(self, dim_e=D, dim_q=360) -> None:
......@@ -253,6 +249,6 @@ class Decoder(nn.Module):
u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
ws = self.eptm(e, cur_q4w, mask=mask_unvisited)
return u, w, s
return u, ws
from turtle import up
import torch
import pdb
......@@ -24,26 +23,23 @@ model = Policy(actor=actor_net,
action_space=env.action_space
)
u0, _ = model.first_act(nodes)
u_list = [u0]
w_list = []
s_list = []
u0 = model.first_act(nodes)
v_list = [u0]
h_list = []
mask_visited = torch.zeros(batch_size, num_nodes).bool()
mask_visited = update_mask(mask_visited, batch_size, u0)
for i in range(1, num_nodes):
u, w, s, u_dist, w_dist, s_dist = model.act(nodes, mask_visited)
mask_visited = update_mask(mask_visited, batch_size, u)
mask_visited = update_mask(mask_visited, batch_size, w)
u_list.append(u)
w_list.append(w)
s_list.append(s)
v, h = model.act(nodes, mask_visited)
mask_visited = update_mask(mask_visited, batch_size, v)
mask_visited = update_mask(mask_visited, batch_size, h)
v_list.append(v)
h_list.append(h)
# transpose into [#num_batch, #num_nodes]
all_u = torch.stack(u_list).transpose(1, 0)
all_w = torch.stack(w_list).transpose(1, 0)
all_s = torch.stack(s_list).transpose(1, 0)
all_v = torch.stack(v_list).transpose(1, 0)
all_h = torch.stack(h_list).transpose(1, 0)
pdb.set_trace()
\ No newline at end of file
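
update_mask is called in this driver loop but is not shown in the diff; a plausible implementation, assumed here purely for illustration, would mark the node chosen by each batch element as visited:

```python
import torch

def update_mask(mask_visited, batch_size, chosen):
    """Assumed behavior: mark the node chosen by each batch element as visited.
    mask_visited: [#num_batch, #num_nodes] bool, chosen: [#num_batch] long."""
    mask_visited = mask_visited.clone()
    mask_visited[torch.arange(batch_size), chosen] = True
    return mask_visited
```
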