Commit 7f1d8bef by lvzhengyang

add "learn" part for Policy

parent 131c30fa
@@ -60,7 +60,7 @@ class Critic(nn.Module):
        @return Expectation for each batch, [#num_batch]
        """
        e = self.encoder(nodes) # [#batch_size, #num_nodes, D]
        glimpse_w = self.g(torch.tanh(e)).squeeze() # [#batch_size, #num_nodes]
        glimpse_w = torch.softmax(glimpse_w, dim=1).unsqueeze(1) # [#batch_size, 1, #num_nodes]
        glimpse = torch.bmm(glimpse_w, e).squeeze() # [#batch_size, D]
        b = self.final(glimpse).squeeze() # [#batch_size]
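        # glimpse is an attention-weighted average of the node embeddings, so b
        # acts as a learned scalar baseline for each batch element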
@@ -70,17 +70,26 @@ class Critic(nn.Module):
the action_space is a gym.spaces.Discrete
the agent outputs probabilities; use torch.distributions.categorical.Categorical() to sample from them
"""
class Policy(nn.Module):
    def __init__(self,
                 actor,
                 critic,
                 obs_space,
                 action_space,
                 optimizer=None,
                 lr=2.5e-4,
                 weight_decay=5e-4,
                 loss_fn_critic=torch.nn.MSELoss()
                 ) -> None:
        super(Policy, self).__init__()

        self.actor = actor
        self.critic = critic
        self.obs_space = obs_space
        self.action_space = action_space
        # actor and critic are registered as submodules, so self.parameters()
        # covers both of them
        if optimizer is None:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.optimizer = optimizer
        self.loss_fn_critic = loss_fn_critic
        self.log_probs_buf = {
            "u0": None,
@@ -88,6 +97,9 @@ class Policy():
            "ws": []
        }
    def forward(self, x):
        raise NotImplementedError
    def first_act(self, nodes):
        """
        @brief perform the action for t == 0, with the query set to 0, to get u0
@@ -162,11 +174,39 @@ class Policy():
        h = torch.stack(h)
        return v, h
    def learn(self, nodes, rewards):
        """
        @brief update the parameters of actor and critic via the REINFORCE algorithm
        @param nodes: [#batch, #num_nodes, 2] in dtype float
        @param rewards: reward of each episode in the batch, [#num_batch]
        @note assumes the env has already returned the 'done' signal
        """
        self.optimizer.zero_grad()

        # compute the log-probability of each trajectory in the batch
        u0 = self.log_probs_buf['u0'] # [#num_batch]
        u = torch.stack(self.log_probs_buf['u']).transpose(1, 0) # [#num_batch, #num_nodes - 1]
        ws = torch.stack(self.log_probs_buf['ws']).transpose(1, 0) # [#num_batch, #num_nodes - 1]
        # p_res: sum of the per-step log-probs, [#num_batch]
        p_res = torch.add(u, ws).sum(dim=-1).add(u0)

        with torch.no_grad():
            baselines = self.critic(nodes)
        j = (baselines - rewards) * p_res
        j = j.mean()

        baselines = self.critic(nodes)
        loss_critic = self.loss_fn_critic(baselines, rewards)

        j.backward()
        loss_critic.backward()
        self.optimizer.step()

        # finally, reset the buffer
        self.log_probs_buf = {
            "u0": None,
            "u": [],
            "ws": []
        }
@@ -45,11 +45,18 @@ class EncoderLayer(nn.Module):
        k = self.key(x)
        v = self.value(x)
        x1, _ = self.attention(query=q, key=k, value=v, need_weights=False)
        x = x.add(x1)
        # x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
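        # The transpose / norm / transpose below assumes self.norm is an nn.BatchNorm1d,
        # which expects the feature dim in the channel position ([N, C, L]); the same
        # pattern is repeated after the feed-forward block and in Encoder.forward.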
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = x.transpose(1, 2)
        x1 = self.feed_foward(x)
        x = x.add(x1)
        # x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = x.transpose(1, 2)
        return x
class Encoder(nn.Module):
@@ -58,7 +65,6 @@ class Encoder(nn.Module):
        self.dim = dim
        self.N = N
        self.W_emb = nn.Linear(2, self.dim, bias=False) # get E_0
        self.encoder_layers = nn.Sequential()
        for i in range(self.N):
            self.encoder_layers.add_module('{}_{}'.format(EncoderLayer.__name__, i),
@@ -72,7 +78,15 @@ class Encoder(nn.Module):
        batch_size = x.shape[0]
        num_nodes = x.shape[1]
        x = self.W_emb(x)
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = x.transpose(1, 2)
        """
        dim = x.shape[-1]
        x = x.view(-1, dim)
        x = self.norm(x)
        x = x.reshape(batch_size, num_nodes, -1)
        """
        x = self.encoder_layers(x)
        return x
@@ -103,15 +117,15 @@ class PTM(nn.Module):
        # e: [#num_batch, #num_nodes, 360]
        # q: [#num_batch, 360]
        # for each batch, perform e + q (q is broadcast over the #num_nodes dim)
        e = e.transpose(0, 1).add(q).transpose(0, 1)
        e = self.W_g(e) # get l
        e = e.squeeze()
        l = e.clone()
        if mask is not None:
            # masked points are set to -INF
            e = torch.where(mask == False, e, -torch.inf)
        e = torch.tanh(e)
        e = e.mul(self.C)
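        # tanh followed by scaling with C bounds the logits to [-C, C] before the
        # softmax, the usual logit-clipping trick in pointer-network style models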
        p = nn.functional.softmax(e, dim=1)
        if need_l:
            return p, l
@@ -197,7 +211,7 @@ class EPTM(nn.Module):
        if mask.dim() == 2:
            mask = torch.stack([mask, mask], dim=1)
        l = torch.where(mask == False, l, -torch.inf)
        l = torch.tanh(l).mul(self.C)
        batch_size = l.shape[0]
        p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)
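        # l appears to be [#num_batch, 2, #num_nodes] here; flattening before the
        # softmax makes p4ws a single joint distribution over all 2 * #num_nodes choices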
        return p4ws
@@ -42,4 +42,7 @@ for i in range(1, num_nodes):
all_v = torch.stack(v_list).transpose(1, 0)
all_h = torch.stack(h_list).transpose(1, 0)
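# smoke-test the new learn() path with constant dummy rewards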
rewards = -10 * torch.ones(batch_size)
model.learn(nodes, rewards)
pdb.set_trace()