Commit c01b9271 by lvzhengyang

fix bug in getting ws_probs

parent 35ab999f
...@@ -36,10 +36,16 @@ class Actor(nn.Module): ...@@ -36,10 +36,16 @@ class Actor(nn.Module):
u0_probs = self.decoder._get_u0_probs(e) u0_probs = self.decoder._get_u0_probs(e)
return u0_probs return u0_probs
def forward(self, nodes, mask_visited=None, mask_unvisited=None): def get_u_probs(self, nodes, mask_visited=None, mask_unvisited=None):
e = self.encoder(nodes) e = self.encoder(nodes)
u_probs, ws_probs = self.decoder(e, mask_visited, mask_unvisited) u_probs = self.decoder.get_u_probs(e, mask_visited, mask_unvisited)
return u_probs, ws_probs return u_probs
def get_ws_probs(self, nodes, mask_visited=None, mask_unvisited=None):
e = self.encoder(nodes)
ws_probs = self.decoder.get_ws_probs(e, self.decoder._last_Eu,
mask_visited, mask_unvisited)
return ws_probs
class Critic(nn.Module): class Critic(nn.Module):
def __init__(self, dim_e=D, dim_c=256) -> None: def __init__(self, dim_e=D, dim_c=256) -> None:
...@@ -137,11 +143,20 @@ class Policy(nn.Module): ...@@ -137,11 +143,20 @@ class Policy(nn.Module):
mask_unvisited = ~mask_visited mask_unvisited = ~mask_visited
num_nodes = nodes.shape[1] num_nodes = nodes.shape[1]
u_probs, ws_probs = self.actor(nodes, u_probs = self.actor.get_u_probs(nodes,
mask_visited=mask_visited, mask_unvisited=mask_unvisited) mask_visited=mask_visited, mask_unvisited=mask_unvisited)
u_dist = Categorical(u_probs) u_dist = Categorical(u_probs)
ws_dist = Categorical(ws_probs)
u = u_dist.sample() u = u_dist.sample()
batch_size = u.shape[0]
_last_Eu = []
for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u[i]])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
ws_probs = self.actor.get_ws_probs(nodes,
mask_visited=mask_visited, mask_unvisited=mask_unvisited)
ws_dist = Categorical(ws_probs)
_w = ws_dist.sample() _w = ws_dist.sample()
self.log_probs_buf['u'].append(u_dist.log_prob(u)) self.log_probs_buf['u'].append(u_dist.log_prob(u))
...@@ -150,15 +165,12 @@ class Policy(nn.Module): ...@@ -150,15 +165,12 @@ class Policy(nn.Module):
s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device) s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device)
w = torch.where(_w < num_nodes, _w, _w - num_nodes) w = torch.where(_w < num_nodes, _w, _w - num_nodes)
batch_size = u.shape[0]
_last_Eu = []
_last_Ew = [] _last_Ew = []
_last_Ev = [] _last_Ev = []
_last_Eh = [] _last_Eh = []
v = [] v = []
h = [] h = []
for i in range(batch_size): for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u[i]])
_last_Ew.append(self.actor._node_e[i, w[i]]) _last_Ew.append(self.actor._node_e[i, w[i]])
if s[i] == 0: if s[i] == 0:
_last_Ev.append(self.actor._node_e[i, u[i]]) _last_Ev.append(self.actor._node_e[i, u[i]])
...@@ -170,7 +182,6 @@ class Policy(nn.Module): ...@@ -170,7 +182,6 @@ class Policy(nn.Module):
_last_Eh.append(self.actor._node_e[i, u[i]]) _last_Eh.append(self.actor._node_e[i, u[i]])
v.append(w[i]) v.append(w[i])
h.append(u[i]) h.append(u[i])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
self.actor.decoder._last_Ew = torch.stack(_last_Ew) self.actor.decoder._last_Ew = torch.stack(_last_Ew)
self.actor.decoder._last_Ev = torch.stack(_last_Ev) self.actor.decoder._last_Ev = torch.stack(_last_Ev)
self.actor.decoder._last_Eh = torch.stack(_last_Eh) self.actor.decoder._last_Eh = torch.stack(_last_Eh)
......
...@@ -182,7 +182,9 @@ class QGen(nn.Module): ...@@ -182,7 +182,9 @@ class QGen(nn.Module):
if cur_u == None: if cur_u == None:
cur_q = torch.relu(last_edge + last_subtree) cur_q = torch.relu(last_edge + last_subtree)
else: else:
cur_q = torch.relu(last_edge + last_subtree) # ERROR! where is u?
tmp = self.W_5(cur_u)
cur_q = torch.relu(last_edge + last_subtree + tmp)
return cur_q return cur_q
class EPTM(nn.Module): class EPTM(nn.Module):
...@@ -246,13 +248,11 @@ class Decoder(nn.Module): ...@@ -246,13 +248,11 @@ class Decoder(nn.Module):
start_node_probs, _ = self.ptm_0(e, q) start_node_probs, _ = self.ptm_0(e, q)
return start_node_probs return start_node_probs
def forward(self, e, mask_visited=None, mask_unvisited=None): def get_u_probs(self, e, mask_visited=None, mask_unvisited=None):
""" """
@param e: input embeddings for all nodes, [#num_batch, #num_nodes, D] @param e: input embeddings for all nodes, [#num_batch, #num_nodes, D]
@param last_u/w/v/h: embeddings for last u/w/v/h
@param last_subtree: embeddings for last subtree
@note define subtree(t=0) = 0 @note define subtree(t=0) = 0
@return probability of u, w, s @return probability of u, ws
u, w: [#num_batch, #num_nodes] u, w: [#num_batch, #num_nodes]
s: [#num_batch, #num_nodes, 2] s: [#num_batch, #num_nodes, 2]
""" """
...@@ -260,9 +260,14 @@ class Decoder(nn.Module): ...@@ -260,9 +260,14 @@ class Decoder(nn.Module):
self._last_subtree = self.subtree_gen(self._last_subtree, self._last_edge) self._last_subtree = self.subtree_gen(self._last_subtree, self._last_edge)
cur_q4u = self.q_gen(self._last_edge, self._last_subtree) cur_q4u = self.q_gen(self._last_edge, self._last_subtree)
u, _ = self.ptm(e, cur_q4u, mask=mask_visited) u_probs, _ = self.ptm(e, cur_q4u, mask=mask_visited)
return u_probs
cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u) def get_ws_probs(self, e, E_u, mask_visited=None, mask_unvisited=None):
ws = self.eptm(e, cur_q4w, mask=mask_unvisited) """
        @param E_u: embedding of the chosen node u
"""
cur_q4w = self.q_gen(self._last_edge, self._last_subtree, E_u)
ws_probs = self.eptm(e, cur_q4w, mask=mask_unvisited)
return u, ws return ws_probs
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment