Commit c01b9271 by lvzhengyang

fix bug in getting ws_probs

The query for the (w, s) head was built without the chosen node u, so
ws_probs could not depend on which u had just been sampled. Split the
Actor/Decoder forward pass into get_u_probs and get_ws_probs: Policy now
samples u, stashes its embedding in decoder._last_Eu, and only then asks
for ws_probs conditioned on it.

parent 35ab999f
@@ -36,10 +36,16 @@ class Actor(nn.Module):
         u0_probs = self.decoder._get_u0_probs(e)
         return u0_probs
 
-    def forward(self, nodes, mask_visited=None, mask_unvisited=None):
+    def get_u_probs(self, nodes, mask_visited=None, mask_unvisited=None):
         e = self.encoder(nodes)
-        u_probs, ws_probs = self.decoder(e, mask_visited, mask_unvisited)
-        return u_probs, ws_probs
+        u_probs = self.decoder.get_u_probs(e, mask_visited, mask_unvisited)
+        return u_probs
+
+    def get_ws_probs(self, nodes, mask_visited=None, mask_unvisited=None):
+        e = self.encoder(nodes)
+        ws_probs = self.decoder.get_ws_probs(e, self.decoder._last_Eu,
+                                             mask_visited, mask_unvisited)
+        return ws_probs
 
 class Critic(nn.Module):
     def __init__(self, dim_e=D, dim_c=256) -> None:
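The hunk above splits Actor.forward into two calls with a decoder-state write in between, so a caller can sample u before asking for the (w, s) distribution. A minimal sketch of the resulting calling pattern, using toy tensors in place of this repo's encoder/decoder (B, N, D and the random score heads are assumptions for illustration, not the real modules):

    import torch
    from torch.distributions import Categorical

    B, N, D = 4, 8, 16                       # batch, nodes, embed dim (assumed)
    node_e = torch.randn(B, N, D)            # stands in for actor._node_e

    u_probs = torch.softmax(torch.randn(B, N), dim=-1)  # ~ actor.get_u_probs(...)
    u = Categorical(u_probs).sample()                   # sample u first

    # stash the chosen node's embedding so the second head can condition on it,
    # mirroring the write to actor.decoder._last_Eu in the Policy hunk below
    last_Eu = node_e[torch.arange(B), u]                # [B, D]

    # ~ actor.get_ws_probs(...): the (w, s) scores now depend on last_Eu
    ws_probs = torch.softmax(last_Eu @ torch.randn(D, 2 * N), dim=-1)
    ws = Categorical(ws_probs).sample()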
@@ -137,11 +143,20 @@ class Policy(nn.Module):
         mask_unvisited = ~mask_visited
         num_nodes = nodes.shape[1]
 
-        u_probs, ws_probs = self.actor(nodes,
+        u_probs = self.actor.get_u_probs(nodes,
             mask_visited=mask_visited, mask_unvisited=mask_unvisited)
         u_dist = Categorical(u_probs)
-        ws_dist = Categorical(ws_probs)
         u = u_dist.sample()
+
+        batch_size = u.shape[0]
+        _last_Eu = []
+        for i in range(batch_size):
+            _last_Eu.append(self.actor._node_e[i, u[i]])
+        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
+
+        ws_probs = self.actor.get_ws_probs(nodes,
+            mask_visited=mask_visited, mask_unvisited=mask_unvisited)
+        ws_dist = Categorical(ws_probs)
         _w = ws_dist.sample()
 
         self.log_probs_buf['u'].append(u_dist.log_prob(u))
@@ -150,15 +165,12 @@ class Policy(nn.Module):
         s = torch.where(_w < num_nodes, int(0), int(1)).to(_w.device)
         w = torch.where(_w < num_nodes, _w, _w - num_nodes)
 
-        batch_size = u.shape[0]
-        _last_Eu = []
         _last_Ew = []
         _last_Ev = []
         _last_Eh = []
         v = []
         h = []
         for i in range(batch_size):
-            _last_Eu.append(self.actor._node_e[i, u[i]])
             _last_Ew.append(self.actor._node_e[i, w[i]])
             if s[i] == 0:
                 _last_Ev.append(self.actor._node_e[i, u[i]])
@@ -170,7 +182,6 @@ class Policy(nn.Module):
                 _last_Eh.append(self.actor._node_e[i, u[i]])
                 v.append(w[i])
                 h.append(u[i])
-        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
         self.actor.decoder._last_Ew = torch.stack(_last_Ew)
         self.actor.decoder._last_Ev = torch.stack(_last_Ev)
         self.actor.decoder._last_Eh = torch.stack(_last_Eh)
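Two things worth seeing in the Policy hunks above. The flat EPTM index _w packs both choices: values below num_nodes mean s = 0 and values at or above mean s = 1, which is exactly what the torch.where pair recovers. And the per-sample Python loops are equivalent to batched indexing; a hedged sketch, assuming _node_e is [batch, num_nodes, D] and u, w, s are [batch] LongTensors (gather_step_embeddings is a hypothetical helper, not in this repo):

    import torch

    def gather_step_embeddings(node_e, u, w, s):
        # vectorized equivalent of the loop filling _last_Eu/_last_Ew/
        # _last_Ev/_last_Eh, v and h: s == 0 keeps (u, w), s == 1 swaps
        batch = torch.arange(node_e.shape[0], device=node_e.device)
        v = torch.where(s == 0, u, w)
        h = torch.where(s == 0, w, u)
        return (node_e[batch, u],   # _last_Eu
                node_e[batch, w],   # _last_Ew
                node_e[batch, v],   # _last_Ev
                node_e[batch, h],   # _last_Eh
                v, h)

    # decoding the flat index, same result as the torch.where pair above
    num_nodes = 6
    _w = torch.tensor([2, 9])
    s = (_w >= num_nodes).long()    # 0 -> first block of scores, 1 -> second
    w = _w % num_nodes              # equals _w - num_nodes when s == 1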
@@ -182,7 +182,9 @@ class QGen(nn.Module):
         if cur_u == None:
             cur_q = torch.relu(last_edge + last_subtree)
         else:
-            cur_q = torch.relu(last_edge + last_subtree)
+            # bug fix: the old query dropped cur_u entirely
+            tmp = self.W_5(cur_u)
+            cur_q = torch.relu(last_edge + last_subtree + tmp)
         return cur_q
 
 class EPTM(nn.Module):
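This hunk is the actual bug named in the commit message: in the old else branch the query was computed without cur_u, so the (w, s) distribution could never depend on which u had just been chosen. A self-contained sketch of the corrected step (the dimension and the Linear projection for W_5 are assumptions; only the sum-then-ReLU structure comes from the hunk):

    import torch
    import torch.nn as nn

    class QGenSketch(nn.Module):
        """Toy stand-in for QGen; only the fixed forward logic is faithful."""
        def __init__(self, dim=128):
            super().__init__()
            self.W_5 = nn.Linear(dim, dim)  # projects the chosen node's embedding

        def forward(self, last_edge, last_subtree, cur_u=None):
            if cur_u is None:
                # query used to pick u: no node chosen yet this step
                return torch.relu(last_edge + last_subtree)
            # the fix: fold the chosen node's embedding into the query
            return torch.relu(last_edge + last_subtree + self.W_5(cur_u))

    q_gen = QGenSketch(dim=8)
    edge, subtree, e_u = torch.randn(3, 2, 8).unbind(0)
    q4u = q_gen(edge, subtree)          # query for choosing u
    q4w = q_gen(edge, subtree, e_u)     # query for (w, s), conditioned on u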
@@ -246,13 +248,11 @@ class Decoder(nn.Module):
         start_node_probs, _ = self.ptm_0(e, q)
         return start_node_probs
 
-    def forward(self, e, mask_visited=None, mask_unvisited=None):
+    def get_u_probs(self, e, mask_visited=None, mask_unvisited=None):
         """
         @param e: input embeddings for all nodes, [#num_batch, #num_nodes, D]
         @param last_u/w/v/h: embeddings for last u/w/v/h
         @param last_subtree: embeddings for last subtree
         @note define subtree(t=0) = 0
-        @return probability of u, w, s
-        u, w: [#num_batch, #num_nodes]
-        s: [#num_batch, #num_nodes, 2]
+        @return probability of u, ws
         """
@@ -260,9 +260,14 @@ class Decoder(nn.Module):
         self._last_subtree = self.subtree_gen(self._last_subtree, self._last_edge)
         cur_q4u = self.q_gen(self._last_edge, self._last_subtree)
-        u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
+        u_probs, _ = self.ptm(e, cur_q4u, mask=mask_visited)
+        return u_probs
 
-        cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
-        ws = self.eptm(e, cur_q4w, mask=mask_unvisited)
-        return u, ws
+    def get_ws_probs(self, e, E_u, mask_visited=None, mask_unvisited=None):
+        """
+        @param E_u: embedding of the node u chosen at this step
+        """
+        cur_q4w = self.q_gen(self._last_edge, self._last_subtree, E_u)
+        ws_probs = self.eptm(e, cur_q4w, mask=mask_unvisited)
+        return ws_probs
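Both heads receive a mask (mask_visited for the u pointer, mask_unvisited for the w/s pointer), the standard pointer-network way of zeroing out forbidden choices. A hedged sketch of that masking, assuming the convention that True marks positions to exclude (PTM/EPTM in this repo may use the opposite polarity):

    import torch

    def masked_pointer_probs(scores, mask):
        # scores: [batch, num_nodes] raw attention scores
        # mask:   [batch, num_nodes] bool, True where a node must not be chosen
        scores = scores.masked_fill(mask, float('-inf'))
        return torch.softmax(scores, dim=-1)

    # toy usage: forbid node 0 for every batch element
    scores = torch.randn(2, 5)
    mask = torch.zeros(2, 5, dtype=torch.bool)
    mask[:, 0] = True
    probs = masked_pointer_probs(scores, mask)
    assert torch.allclose(probs[:, 0], torch.zeros(2))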