Commit 592f1b41 by lvzhengyang

add act to agent, and built the test code

parent 3dc620ab
@@ -4,29 +4,46 @@
@date: 2022.9.1
"""
from model import PTM, Encoder, Decoder, D
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from env.env import RSMTEnv
import pdb

# TODO: finish the mask
# leave the mask provided by the environment
# TODO: make the vectorized version
class Actor(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(Actor, self).__init__()
        self.encoder = Encoder(dim=dim_e, N=3)
        self.decoder = Decoder(dim_e, dim_q)
        self.dim_e = dim_e
        self.dim_q = dim_q
        self._node_e = None  # node embeddings
    def get_u0_probs(self, nodes):
        """
        @brief get a probability distribution of u0
        @param nodes: [#batch_size, #num_nodes, 2]
        @return probs [#batch_size, #num_nodes]
        """
        e = self.encoder(nodes)
        self._node_e = e.clone()
        u0_probs = self.decoder._get_u0_probs(e)
        return u0_probs

    def forward(self, nodes, mask_visited=None, mask_unvisited=None):
        e = self.encoder(nodes)
        u_probs, w_probs, s_probs = self.decoder(e, mask_visited, mask_unvisited)
        return u_probs, w_probs, s_probs
class Critic(nn.Module):
    def __init__(self, dim_e=D, dim_c=256) -> None:
        super(Critic, self).__init__()
        self.dim_e = dim_e
        self.dim_c = dim_c
        self.encoder = Encoder(dim=dim_e, N=3)
@@ -48,3 +65,89 @@ class Critic(nn.Module):
        b = self.final(glimpse).squeeze()  # [#batch_size]
        return b
"""
the action_space is of gym.spaces.Discrete
the agent outputs probabilities, use torch.distributions.categorical.Categorical()
"""
class Policy():
    def __init__(self,
                 actor,
                 critic,
                 obs_space,
                 action_space,
                 ) -> None:
        self.obs_space = obs_space
        self.action_space = action_space
        self.actor = actor
        self.critic = critic
    def first_act(self, nodes):
        """
        @brief action at t == 0: with the query set to 0, pick the first node u0
        @param nodes: [#num_batch, #num_nodes, 2]
        @return u0: [#num_batch]; dist: Categorical over [#num_batch, #num_nodes] probs
        """
        probs = self.actor.get_u0_probs(nodes)  # [#num_batch, #num_nodes]
        dist = Categorical(probs)
        u0 = dist.sample()
        batch_size = u0.shape[0]
        _last_Eu = []
        for i in range(batch_size):
            _last_Eu.append(self.actor._node_e[i, u0[i]])
        self.actor.decoder._last_Eu = torch.stack(_last_Eu)
        self.actor.decoder._last_Ev = self.actor.decoder._last_Eu.clone()
        return u0, dist
    def act(self, nodes, mask_visited=None, mask_unvisited=None):
        """
        @brief perform the action for t >= 1
        @param nodes: [#num_batch, #num_nodes, 2]
        @param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
        TODO: gather the input into one obs that contains a batch of obs
        """
        if mask_visited is None and mask_unvisited is not None:
            mask_visited = ~mask_unvisited
        if mask_unvisited is None and mask_visited is not None:
            mask_unvisited = ~mask_visited
        u_probs, _w_probs, s_probs = self.actor(nodes,
                mask_visited=mask_visited, mask_unvisited=mask_unvisited)
        u_dist = Categorical(u_probs)
        _w_dist = Categorical(_w_probs)  # candidate w for each s; the sampled s picks one below
        s_dist = Categorical(s_probs)
        u = u_dist.sample()
        _w = _w_dist.sample()
        s = s_dist.sample()
        batch_size = u.shape[0]
        _last_Eu = []
        for i in range(batch_size):
            _last_Eu.append(self.actor._node_e[i, u[i]])
        self.actor.decoder._last_Eu = torch.stack(_last_Eu)

        _last_Ew = []
        _last_Ev = []
        _last_Eh = []
        w = []
        w_probs = []
        # pick the w candidate that matches the sampled s for each batch entry
        for i in range(batch_size):
            if s[i] == 0:
                _last_Ev.append(self.actor._node_e[i, u[i]])
                _last_Eh.append(self.actor._node_e[i, _w[i, 0]])
                _last_Ew.append(self.actor._node_e[i, _w[i, 0]])
                w.append(_w[i, 0])
                w_probs.append(_w_probs[i, 0])
            else:
                _last_Ev.append(self.actor._node_e[i, _w[i, 1]])
                _last_Eh.append(self.actor._node_e[i, u[i]])
                _last_Ew.append(self.actor._node_e[i, _w[i, 1]])
                w.append(_w[i, 1])
                w_probs.append(_w_probs[i, 1])
        self.actor.decoder._last_Ev = torch.stack(_last_Ev)
        self.actor.decoder._last_Eh = torch.stack(_last_Eh)
        self.actor.decoder._last_Ew = torch.stack(_last_Ew)
        # gather the chosen w per batch entry
        w = torch.stack(w).to(u.device)
        w_probs = torch.stack(w_probs)
        w_dist = Categorical(w_probs)
        return u, w, s, u_dist, w_dist, s_dist
@@ -12,7 +12,7 @@ D = 128
class EncoderLayer(nn.Module):
    def __init__(self, num_heads=16, dim_per_head=16, dim=D, h_dim=512) -> None:
        super(EncoderLayer, self).__init__()
        self.num_heads = num_heads
        self.dim_per_head = dim_per_head  # seems unneeded
        self.dim = dim
@@ -54,7 +54,7 @@ class EncoderLayer(nn.Module):
class Encoder(nn.Module):
    def __init__(self, dim=D, N=3) -> None:
        super(Encoder, self).__init__()
        self.dim = dim
        self.N = N
        self.W_emb = nn.Linear(2, self.dim, bias=False)  # get E_0
@@ -79,7 +79,7 @@ class Encoder(nn.Module):
class PTM(nn.Module):
    # PTM: PoinTing Mechanism
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(PTM, self).__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.W_e = nn.Linear(dim_e, dim_q, bias=False)
@@ -112,17 +112,17 @@ class PTM(nn.Module):
        e = torch.where(mask == False, e, -torch.inf)
        e = torch.tanh(e)
        e.mul_(self.C)
        p = nn.functional.softmax(e, dim=1)
        if need_l:
            return p, l
        return p, None
class EdgeGen(nn.Module):
    """
    @brief Generate query for t-th action
    """
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(EdgeGen, self).__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.W_u = nn.Linear(self.dim_e, self.dim_q, bias=False)
@@ -130,41 +130,55 @@ class EdgeGen(nn.Module):
        self.W_v = nn.Linear(self.dim_e, self.dim_q, bias=False)
        self.W_h = nn.Linear(self.dim_e, self.dim_q, bias=False)

    def forward(self, u, w, v, h):
        """
        @param u, w, v, h: [#num_batch, D]
        """
        # full form: edge = self.W_u(u) + self.W_w(w) + self.W_v(v) + self.W_h(h)
        # at t = 1, last_u == last_v == u0 while last_w and last_h are still None,
        # so the w and h terms are only added when those embeddings exist
        edge = self.W_u(u) + self.W_v(v)
        if w is not None:
            edge += self.W_w(w)
        if h is not None:
            edge += self.W_h(h)
        return edge
class SubTreeGen(nn.Module):
    def __init__(self, dim_q=360) -> None:
        super(SubTreeGen, self).__init__()
        self.dim_q = dim_q
        self.W_edge = nn.Linear(self.dim_q, self.dim_q, bias=False)

    def forward(self, last_subtree, cur_edge):
        if torch.is_tensor(last_subtree):
            cur_subtree = torch.max(last_subtree, self.W_edge(cur_edge))
        else:
            # last_subtree is set to 0 at the first iteration
            cur_subtree = torch.relu(self.W_edge(cur_edge))
        return cur_subtree
class QGen(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(QGen, self).__init__()
        self.W_5 = nn.Linear(dim_e, dim_q, bias=False)

    def forward(self, last_edge, last_subtree, cur_u=None):
        if cur_u is None:
            cur_q = torch.relu(last_edge + last_subtree)
        else:
            # cur_u is currently ignored (the self.W_5 term is left out), so both branches match
            cur_q = torch.relu(last_edge + last_subtree)
        return cur_q
class EPTM(nn.Module):
    # EPTM: Extended PoinTing Mechanism
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(EPTM, self).__init__()
        self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
        self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
        # to get s
        self.q_encode = nn.Linear(dim_q + dim_e, 2, bias=False)
        self.C = 10.0

    def forward(self, e, q, mask=None):
@@ -177,7 +191,7 @@ class EPTM(nn.Module):
        @note result[i, 0] is for batch i and s=0
        """
        _, l_0 = self.PTM_0(e, q, need_l=True)  # [#num_batch, #num_nodes]
        _, l_1 = self.PTM_1(e, q, need_l=True)  # [#num_batch, #num_nodes]
        l = torch.stack([l_0, l_1], dim=1)  # [#num_batch, 2, #num_nodes]
        if mask is not None:
            if mask.dim() == 2:
@@ -185,13 +199,16 @@ class EPTM(nn.Module):
            l = torch.where(mask == False, l, -torch.inf)
        l = torch.tanh(l).mul_(self.C)
        p4w = torch.softmax(l, dim=-1)  # probability of w, [#num_batch, 2, #num_nodes]

        e_mean = torch.mean(e, dim=1).squeeze()
        s = torch.cat([e_mean, q], dim=1)
        s = self.q_encode(s)
        p4s = torch.softmax(s, dim=-1)  # probability of s, [#num_batch, 2]
        return p4w, p4s
class Decoder(nn.Module):
    def __init__(self, dim_e=D, dim_q=360) -> None:
        super(Decoder, self).__init__()
        self.dim_e = dim_e
        self.dim_q = dim_q
        self.edge_gen = EdgeGen(dim_e, dim_q)
@@ -199,10 +216,27 @@ class Decoder(nn.Module):
        self.ptm = PTM(dim_e, dim_q)
        self.eptm = EPTM(dim_e, dim_q)
        self.q_gen = QGen(dim_e, dim_q)

        self.ptm_0 = PTM(dim_e=dim_e, dim_q=dim_q)  # for generating u0

        self._last_subtree = 0  # should be a vector, but set to 0 at first
        self._last_edge = None
        self._last_Eu = None
        self._last_Ew = None
        self._last_Ev = None
        self._last_Eh = None
    def _get_u0_probs(self, e):
        """
        @brief get a probability distribution of u0
        @param e: [#batch_size, #num_nodes, self.dim_e]
        @return probs [#batch_size, #num_nodes]
        """
        q = torch.zeros(e.shape[0], self.dim_q, device=e.device)
        start_node_probs, _ = self.ptm_0(e, q)
        return start_node_probs
    def forward(self, e, mask_visited=None, mask_unvisited=None):
        """
        @param e: input embeddings for all nodes, [#num_batch, #num_nodes, D]
        @note the last u/w/v/h embeddings are read from self._last_Eu/_last_Ew/_last_Ev/_last_Eh
@@ -212,9 +246,13 @@ class Decoder(nn.Module):
            u: [#num_batch, #num_nodes]
            w: [#num_batch, 2, #num_nodes]
            s: [#num_batch, 2]
        """
        self._last_edge = self.edge_gen(self._last_Eu, self._last_Ew, self._last_Ev, self._last_Eh)
        self._last_subtree = self.subtree_gen(self._last_subtree, self._last_edge)
        cur_q4u = self.q_gen(self._last_edge, self._last_subtree)
        u, _ = self.ptm(e, cur_q4u, mask=mask_visited)

        cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
        w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
        return u, w, s
import torch
import pdb
from agent import Actor, Critic, Policy
from env.env import RSMTEnv


def update_mask(mask, batch_size, node_visited):
    for i in range(batch_size):
        mask[i][node_visited[i]] = True
    return mask
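# Hedged sketch (not part of this commit; the name is hypothetical): a vectorized version of
# update_mask. Tensor.scatter_ sets mask[i, node_visited[i]] = True for every batch row at once,
# matching the Python loop above when node_visited has shape [#batch_size].
def update_mask_vectorized(mask, node_visited):
    return mask.scatter_(1, node_visited.unsqueeze(1), True)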
batch_size = 4
num_nodes = 8
nodes = torch.randn(batch_size, num_nodes, 2)

env = RSMTEnv(num_nodes=num_nodes, pos_l=0, pos_h=100)

actor_net = Actor()
critic_net = Critic()
model = Policy(actor=actor_net,
               critic=critic_net,
               obs_space=env.observation_space,
               action_space=env.action_space
               )
u0, _ = model.first_act(nodes)
u_list = [u0]
w_list = []
s_list = []
mask_visited = torch.zeros(batch_size, num_nodes).bool()
mask_visited = update_mask(mask_visited, batch_size, u0)

for i in range(1, num_nodes):
    u, w, s, u_dist, w_dist, s_dist = model.act(nodes, mask_visited)
    mask_visited = update_mask(mask_visited, batch_size, u)
    mask_visited = update_mask(mask_visited, batch_size, w)
    u_list.append(u)
    w_list.append(w)
    s_list.append(s)
# stack the per-step results and transpose them into [#num_batch, #num_steps]
all_u = torch.stack(u_list).transpose(1, 0)
all_w = torch.stack(w_list).transpose(1, 0)
all_s = torch.stack(s_list).transpose(1, 0)

pdb.set_trace()