Commit 592f1b41 by lvzhengyang

add act to agent, and built the test code

parent 3dc620ab
......@@ -4,29 +4,46 @@
@date: 2022.9.1
"""
from model import Encoder, Decoder, D
from model import PTM, Encoder, Decoder, D
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from env.env import RSMTEnv
import pdb
# TODO: finish the mask
# leave the mask provided by the environment
# TODO: make the vectorized version
class Actor(nn.Module):
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(Actor, self).__init__()
self.encoder = Encoder(dim=dim_e, N=3)
self.decoder = Decoder(dim_e, dim_q)
self.dim_e = dim_e
self.dim_q = dim_q
self._node_e = None # node embeddings
def forward(self, nodes, last_u, last_w, last_v, last_h,
last_subtree, mask_visited, mask_unvisited):
def get_u0_probs(self, nodes):
"""
@brief get a probability distribution of u0
@param nodes: [#batch_size, #num_nodes, 2]
@return probs [#batch_size, #num_nodes]
"""
e = self.encoder(nodes)
u, w, s = self.decoder(e, last_u, last_w, last_v, last_h,
last_subtree, mask_visited, mask_unvisited)
return u, w, s
self._node_e = e.clone()
u0_probs = self.decoder._get_u0_probs(e)
return u0_probs
def forward(self, nodes, mask_visited=None, mask_unvisited=None):
e = self.encoder(nodes)
u_probs, w_probs, s_probs = self.decoder(e, mask_visited, mask_unvisited)
return u_probs, w_probs, s_probs
class Critic(nn.Module):
def __init__(self, dim_e=D, dim_c=256) -> None:
super().__init__()
super(Critic, self).__init__()
self.dim_e = dim_e
self.dim_c = dim_c
self.encoder = Encoder(dim=dim_e, N=3)
......@@ -48,3 +65,89 @@ class Critic(nn.Module):
b = self.final(glimpse).squeeze() # [#batch_size]
return b
"""
the action_space is of gym.spaces.Discrete
the agent outputs probabilities, use torch.distributions.categorical.Categorical()
"""
class Policy():
def __init__(self,
actor,
critic,
obs_space,
action_space,
) -> None:
self.obs_space = obs_space
self.action_space = action_space
self.actor = actor
self.critic = critic
def first_act(self, nodes):
"""
@brief perform action for t == 0, with the query = 0, get u0
@param nodes: [#num_batch, #num_nodes, 2]
@return points: [#num_batch], dist of probs [#num_batch, #num_nodes]
"""
probs = self.actor.get_u0_probs(nodes) # [#num_batch, #num_nodes]
dist = Categorical(probs)
u0 = dist.sample()
batch_size = u0.shape[0]
_last_Eu = []
for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u0[i]])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
self.actor.decoder._last_Ev = self.actor.decoder._last_Eu.clone()
return u0, dist
def act(self, nodes, mask_visited=None, mask_unvisited=None):
"""
@brief perform action for t >= 1
@param nodes: [#num_batch, #num_nodes, 2]
@param mask_visited/mask_unvisited: [#num_batch, #num_nodes]
TODO: gather the input into one obs: contains a batch of obs
"""
if mask_visited == None and not mask_unvisited == None:
mask_visited = ~mask_unvisited
if mask_unvisited == None and not mask_visited == None:
mask_unvisited = ~mask_visited
u_probs, _w_probs, s_probs = self.actor(nodes,
mask_visited=mask_visited, mask_unvisited=mask_unvisited)
u_dist = Categorical(u_probs)
_w_dist = Categorical(_w_probs) # wait to be choice by s
s_dist = Categorical(s_probs)
u = u_dist.sample()
_w = _w_dist.sample()
s = s_dist.sample()
batch_size = u.shape[0]
_last_Eu = []
for i in range(batch_size):
_last_Eu.append(self.actor._node_e[i, u[i]])
self.actor.decoder._last_Eu = torch.stack(_last_Eu)
_last_Ew = []
_last_Ev = []
_last_Eh = []
w = []
w_probs = []
for i in range(batch_size):
if s[i] == 0:
_last_Ev.append(self.actor._node_e[i, u[i]])
_last_Eh.append(self.actor._node_e[i, _w[i][0]])
_last_Ew.append(self.actor._node_e[i, _w[i][0]])
w.append(_w[i, 0])
w_probs.append(_w_probs[i, 0])
else:
_last_Ev.append(self.actor._node_e[i, _w[i][1]])
_last_Eh.append(self.actor._node_e[i, u[i]])
_last_Ew.append(self.actor._node_e[i, _w[i][1]])
w.append(_w[i, 1])
w_probs.append(_w_probs[i, 0])
self.actor.decoder._last_Ev = torch.stack(_last_Ev)
self.actor.decoder._last_Eh = torch.stack(_last_Eh)
self.actor.decoder._last_Ew = torch.stack(_last_Ew)
# get the choiced w
w = torch.tensor(w, device=u.device)
w_probs = torch.stack(w_probs)
w_dist = Categorical(w_probs)
return u, w, s, u_dist, w_dist, s_dist
......@@ -12,7 +12,7 @@ D = 128
class EncoderLayer(nn.Module):
def __init__(self, num_heads=16, dim_per_head=16, dim=D, h_dim=512) -> None:
super().__init__()
super(EncoderLayer, self).__init__()
self.num_heads = num_heads
self.dim_per_head = dim_per_head # seem no need
self.dim = dim
......@@ -54,7 +54,7 @@ class EncoderLayer(nn.Module):
class Encoder(nn.Module):
def __init__(self, dim=D, N=3) -> None:
super().__init__()
super(Encoder, self).__init__()
self.dim = dim
self.N = N
self.W_emb = nn.Linear(2, self.dim, bias=False) # get E_0
......@@ -79,7 +79,7 @@ class Encoder(nn.Module):
class PTM(nn.Module):
# PTM: PoinTing Mechanism
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(PTM, self).__init__()
self.dim_e = dim_e
self.dim_q = dim_q
self.W_e = nn.Linear(dim_e, dim_q, bias=False)
......@@ -112,17 +112,17 @@ class PTM(nn.Module):
e = torch.where(mask == False, e, -torch.inf)
e = torch.tanh(e)
e.mul_(self.C)
e = nn.functional.softmax(e, dim=1)
p = nn.functional.softmax(e, dim=1)
if need_l:
return e, l
return e, None
return p, l
return p, None
class EdgeGen(nn.Module):
"""
@brief Generate query for t-th action
"""
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(EdgeGen, self).__init__()
self.dim_e = dim_e
self.dim_q = dim_q
self.W_u = nn.Linear(self.dim_e, self.dim_q, bias=False)
......@@ -130,41 +130,55 @@ class EdgeGen(nn.Module):
self.W_v = nn.Linear(self.dim_e, self.dim_q, bias=False)
self.W_h = nn.Linear(self.dim_e, self.dim_q, bias=False)
def forward(self, u, w, v, h, last_subtree):
def forward(self, u, w, v, h):
"""
@param u, w, v, h: [#num_batch, D]
"""
edge = self.W_u(u) + self.W_w(w) + self.W_v(v) + self.W_h(h)
# edge = self.W_u(u) + self.W_w(w) + self.W_v(v) + self.W_h(h)
"""
in t = 1, last_u == last_v = u0, last_w = None, last_h = None
"""
edge = self.W_u(u) + self.W_v(v)
if not w == None:
edge += self.W_w(w)
if not h == None:
edge += self.W_h(h)
return edge
class SubTreeGen(nn.Module):
def __init__(self, dim_q=360) -> None:
super().__init__()
super(SubTreeGen, self).__init__()
self.dim_q = dim_q
self.W_edge = nn.Linear(self.dim_q, self.dim_q, bias=False)
def forward(self, last_subtree, cur_edge):
if torch.is_tensor(last_subtree):
cur_subtree = torch.max(last_subtree, self.W_edge(cur_edge))
else:
# last_subtree is set to 0 at first iteration
cur_subtree = torch.relu(self.W_edge(cur_edge))
return cur_subtree
class QGen(nn.Module):
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(QGen, self).__init__()
self.W_5 = nn.Linear(dim_e, dim_q, bias=False)
def forward(self, last_edge, last_subtree, cur_u=None):
if cur_u == None:
cur_q = torch.max(0.0, last_edge + last_subtree)
cur_q = torch.relu(last_edge + last_subtree)
else:
cur_q = torch.max(0.0, last_edge + last_subtree + self.W_5(cur_u))
cur_q = torch.relu(last_edge + last_subtree)
return cur_q
class EPTM(nn.Module):
# EPTM: Extended PoinTing Mechanism
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(EPTM, self).__init__()
self.PTM_0 = PTM(dim_e=dim_e, dim_q=dim_q)
self.PTM_1 = PTM(dim_e=dim_e, dim_q=dim_q)
# to get s
self.q_encode = nn.Linear(dim_q + dim_e, 2, bias=False)
self.C = 10.0
def forward(self, e, q, mask=None):
......@@ -177,7 +191,7 @@ class EPTM(nn.Module):
@note result[i, 0] is for batch i and s=0
"""
_, l_0 = self.PTM_0(e, q, need_l=True) # [#num_batch, #num_nodes]
_, l_1 = self.PTM_0(e, q, need_l=True) # [#num_batch, #num_nodes]
_, l_1 = self.PTM_1(e, q, need_l=True) # [#num_batch, #num_nodes]
l = torch.stack([l_0, l_1], dim=1) # [#num_batch, 2, #num_nodes]
if mask is not None:
if mask.dim() == 2:
......@@ -185,13 +199,16 @@ class EPTM(nn.Module):
l = torch.where(mask == False, l, -torch.inf)
l = torch.tanh(l).mul_(self.C)
p4w = torch.softmax(l, dim=-1) # probability of w, [#num_batch, 2, #num_nodes]
p4s = torch.softmax(l.transpose(1, -1), dim=-1) # probabiliti of s, [#num_batch, #num_nodes, 2]
# p4s[i, j, 0]: in batch i, if node j is selected as w, the probability for s = 0
e_mean = torch.mean(e, dim=1).squeeze()
s = torch.cat([e_mean, q], dim=1)
s = self.q_encode(s)
p4s = torch.softmax(s ,dim=-1)
return p4w, p4s
class Decoder(nn.Module):
def __init__(self, dim_e=D, dim_q=360) -> None:
super().__init__()
super(Decoder, self).__init__()
self.dim_e = dim_e
self.dim_q = dim_q
self.edge_gen = EdgeGen(dim_e, dim_q)
......@@ -199,10 +216,27 @@ class Decoder(nn.Module):
self.ptm = PTM(dim_e, dim_q)
self.eptm = EPTM(dim_e, dim_q)
self.q_gen = QGen(dim_e, dim_q)
# TODO: finish the mask
def forward(self, e, last_u, last_w, last_v, last_h,
last_subtree, mask_visited=None, mask_unvisited=None):
self.ptm_0 = PTM(dim_e=dim_e, dim_q=dim_q) # for generating u0
self._last_subtree = 0 # should be a vector, but set to 0 at first
self._last_edge = None
self._last_Eu = None
self._last_Ew = None
self._last_Ev = None
self._last_Eh = None
def _get_u0_probs(self, e):
"""
@brief get a probability distribution of u0
@param e: [#batch_size, #num_nodes, self.dim_e]
@return probs [#batch_size, #num_nodes]
"""
q = torch.zeros(e.shape[0], self.dim_q, device=e.device)
start_node_probs, _ = self.ptm_0(e, q)
return start_node_probs
def forward(self, e, mask_visited=None, mask_unvisited=None):
"""
@param e: input embeddings for all nodes, [#num_batch, #num_nodes, D]
@param last_u/w/v/h: embeddings for last u/w/v/h
......@@ -212,9 +246,13 @@ class Decoder(nn.Module):
u, w: [#num_batch, #num_nodes]
s: [#num_batch, #num_nodes, 2]
"""
last_edge = self.edge_gen(last_u, last_w, last_v, last_h)
cur_q4u = self.q_gen(last_edge, last_subtree)
self._last_edge = self.edge_gen(self._last_Eu, self._last_Ew, self._last_Ev, self._last_Eh)
self._last_subtree = self.subtree_gen(self._last_subtree, self._last_edge)
cur_q4u = self.q_gen(self._last_edge, self._last_subtree)
u, _ = self.ptm(e, cur_q4u, mask=mask_visited)
cur_q4w = self.q_gen(last_edge, last_subtree, u)
cur_q4w = self.q_gen(self._last_edge, self._last_subtree, u)
w, s = self.eptm(e, cur_q4w, mask=mask_unvisited)
return u, w, s
from turtle import up
import torch
import pdb
from model import Encoder, PTM, EPTM, Critic
from agent import Actor, Critic, Policy
from env.env import RSMTEnv
encoder = Encoder()
def update_mask(mask, batch_size, node_visited):
for i in range(batch_size):
mask[i][node_visited[i]] = True
return mask
batch_size = 4
num_nodes = 8
x = torch.randn(batch_size, num_nodes, 2)
e = encoder(x)
ptm_0 = PTM() # for starting point
q = torch.zeros(4 * 360).reshape(4, 360) # query for starting point
mask = torch.randint(0, 2, (batch_size, num_nodes)).bool()
output, _ = ptm_0(e, q, mask)
start_points = output.max(dim=1)
eptm = EPTM()
p_1 = eptm(e, q, mask)
critic = Critic()
critic(x)
nodes = torch.randn(batch_size, num_nodes, 2)
env = RSMTEnv(num_nodes=num_nodes, pos_l=0, pos_h=100)
actor_net = Actor()
critic_net = Critic()
model = Policy(actor=actor_net,
critic=critic_net,
obs_space=env.observation_space,
action_space=env.action_space
)
u0, _ = model.first_act(nodes)
u_list = [u0]
w_list = []
s_list = []
mask_visited = torch.zeros(batch_size, num_nodes).bool()
mask_visited = update_mask(mask_visited, batch_size, u0)
for i in range(1, num_nodes):
u, w, s, u_dist, w_dist, s_dist = model.act(nodes, mask_visited)
mask_visited = update_mask(mask_visited, batch_size, u)
mask_visited = update_mask(mask_visited, batch_size, w)
u_list.append(u)
w_list.append(w)
s_list.append(s)
# transpose into [#num_batch, #num_nodes]
all_u = torch.stack(u_list).transpose(1, 0)
all_w = torch.stack(w_list).transpose(1, 0)
all_s = torch.stack(s_list).transpose(1, 0)
pdb.set_trace()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment