Commit 7f1d8bef by lvzhengyang

add "learn" part for Policy

parent 131c30fa
......@@ -60,7 +60,7 @@ class Critic(nn.Module):
@return Expectation for each batch, [#num_batch]
"""
e = self.encoder(nodes) # [#batch_size, #num_nodes, D]
glimpse_w = self.g(torch.tanh(e)).squeeze_() # [#batch_size, #num_nodes]
glimpse_w = self.g(torch.tanh(e)).squeeze() # [#batch_size, #num_nodes]
glimpse_w = torch.softmax(glimpse_w, dim=1).unsqueeze(1) # [#batch_size, 1, #num_nodes]
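# attention-pool the node embeddings into a single glimpse vector per batch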
glimpse = torch.bmm(glimpse_w, e).squeeze() # [#batch_size, D]
b = self.final(glimpse).squeeze() # [#batch_size]
......@@ -70,17 +70,26 @@ class Critic(nn.Module):
the action_space is of type gym.spaces.Discrete;
the agent outputs probabilities, sampled with torch.distributions.categorical.Categorical()
"""
class Policy():
class Policy(nn.Module):
def __init__(self,
actor,
critic,
obs_space,
action_space,
optimizer=None,
lr=2.5e-4,
weight_decay=5e-4,
loss_fn_critic=torch.nn.MSELoss()
) -> None:
self.obs_space = obs_space
self.action_space = action_space
super(Policy, self).__init__()
self.actor = actor
self.critic = critic
self.obs_space = obs_space
self.action_space = action_space
if optimizer is None:
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
else:
self.optimizer = optimizer
self.loss_fn_critic = loss_fn_critic
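# buffer of per-step action log-probabilities (u0, u, ws) accumulated over an episode and consumed by learn()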
self.log_probs_buf = {
"u0": None,
......@@ -88,6 +97,9 @@ class Policy():
"ws": []
}
def forward(self, x):
raise NotImplementedError
def first_act(self, nodes):
"""
@brief perform action for t == 0, with the query = 0, get u0
......@@ -162,11 +174,39 @@ class Policy():
h = torch.stack(h)
return v, h
def learn(self, rewards):
def learn(self, nodes, rewards):
"""
@brief update the parameters of actor and critic
@brief update the parameters of the actor and critic via the REINFORCE algorithm
@param nodes: [#batch, #num_nodes, 2] in dtype float
@param rewards: reward of each episode in the batch, [#num_batch]
@note assumes the env has already returned the 'done' signal
"""
pass
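# REINFORCE with a learned baseline: push the actor toward trajectories whose
# reward beats the critic's estimate, and regress the critic onto the rewards.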
self.optimizer.zero_grad()
# compute log(p_res) for each trajectory in the batch
u0 = self.log_probs_buf['u0'] # [#num_batch]
u = torch.stack(self.log_probs_buf['u']).transpose(1, 0) # [#num_batch, #num_nodes - 1]
ws = torch.stack(self.log_probs_buf['ws']).transpose(1, 0) # [#num_batch, #num_nodes - 1]
# p_res: [#num_batch]
p_res = torch.add(u, ws).sum(dim=-1).add(u0)
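# baseline from the critic, computed under no_grad so the actor loss does not backpropagate into the critic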
with torch.no_grad():
baselines = self.critic(nodes)
j = (baselines - rewards) * p_res
j = j.mean()
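# re-evaluate the critic with gradients enabled so loss_critic updates only the critic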
baselines = self.critic(nodes)
loss_critic = self.loss_fn_critic(baselines, rewards)
j.backward()
loss_critic.backward()
self.optimizer.step()
# Finally, reset the buf
self.log_probs_buf = {
"u0": None,
"u": [],
"ws": []
}
......@@ -45,11 +45,18 @@ class EncoderLayer(nn.Module):
k = self.key(x)
v = self.value(x)
x1, _ = self.attention(query=q, key=k, value=v, need_weights=False)
x.add_(x1)
x = self.norm(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
x = x.add(x1)
# x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
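# move the feature dim to dim 1 for self.norm (presumably a BatchNorm1d-style layer that normalizes dim 1), then move it back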
x = x.transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2)
x1 = self.feed_foward(x)
x.add_(x1)
x = self.norm(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
x = x.add(x1)
# x = self.norm(x.reshape(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
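# same feature-dim-first normalization pattern for the feed-forward residual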
x = x.transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2)
return x
class Encoder(nn.Module):
......@@ -58,7 +65,6 @@ class Encoder(nn.Module):
self.dim = dim
self.N = N
self.W_emb = nn.Linear(2, self.dim, bias=False) # get E_0
# TODO: implement batch norm
self.encoder_layers = nn.Sequential()
for i in range(self.N):
self.encoder_layers.add_module('{}_{}'.format(EncoderLayer.__name__, i),
......@@ -72,7 +78,15 @@ class Encoder(nn.Module):
batch_size = x.shape[0]
num_nodes = x.shape[1]
x = self.W_emb(x)
x = self.norm(x.view(-1, x.shape[-1])).reshape(batch_size, num_nodes, -1)
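# normalize the initial embedding the same way: feature dim to dim 1 for the norm layer, then back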
x = x.transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2)
"""
dim = x.shape[-1]
x = x.view(-1, dim)
x = self.norm(x)
x = x.reshape(batch_size, num_nodes, -1)
"""
x = self.encoder_layers(x)
return x
......@@ -103,15 +117,15 @@ class PTM(nn.Module):
# e: [#num_batch, #num_nodes, 360]
# q: [#num_batch, 360]
# for each batch, perform e + q (q is broadcast across all #num_nodes)
e.transpose_(0, 1).add_(q).transpose_(0, 1)
e = e.transpose(0, 1).add(q).transpose(0, 1)
e = self.W_g(e) # get l
e.squeeze_()
e = e.squeeze()
l = e.clone()
if mask is not None:
# points that are masked are set to -INF
e = torch.where(mask == False, e, -torch.inf)
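# bound the logits to (-C, C) with tanh scaling before the softmax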
e = torch.tanh(e)
e.mul_(self.C)
e = e.mul(self.C)
p = nn.functional.softmax(e, dim=1)
if need_l:
return p, l
......@@ -197,7 +211,7 @@ class EPTM(nn.Module):
if mask.dim() == 2:
mask = torch.stack([mask, mask], dim=1)
l = torch.where(mask == False, l, -torch.inf)
l = torch.tanh(l).mul_(self.C)
l = torch.tanh(l).mul(self.C)
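# same (-C, C) logit clipping, then softmax over the flattened 2 * #num_nodes candidates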
batch_size = l.shape[0]
p4ws = torch.softmax(l.reshape(batch_size, -1), dim=-1)
return p4ws
......
......@@ -42,4 +42,7 @@ for i in range(1, num_nodes):
all_v = torch.stack(v_list).transpose(1, 0)
all_h = torch.stack(h_list).transpose(1, 0)
rewards = -10 * torch.ones(batch_size)
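# constant placeholder reward, presumably just to exercise the learn() path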
model.learn(nodes, rewards)
pdb.set_trace()
\ No newline at end of file